Skip to content

Commit f587cc8

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
[draft] Optimized API for Daglish read-only traversals (traversals that do not
re-construct objects). PiperOrigin-RevId: 614885965
1 parent 2a17618 commit f587cc8

File tree

1 file changed

+50
-2
lines changed

1 file changed

+50
-2
lines changed

fiddle/_src/daglish.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ class NamedTupleType:
143143
# This has the same type as Path, but different semantic meaning.
144144
PathElements = Tuple[PathElement, ...]
145145
PathElementsFn = Callable[[Any], PathElements]
146+
_ValueAndPath = Tuple[Any, PathElement]
147+
OptimizedFlattenFn = Callable[[Any], Iterable[_ValueAndPath]]
146148
FlattenFn = Callable[[Any], Tuple[Tuple[Any, ...], Any]]
147149
UnflattenFn = Callable[[Iterable[Any], Any], Any]
148150

@@ -155,6 +157,7 @@ class NodeTraverser:
155157
flatten: FlattenFn
156158
unflatten: UnflattenFn
157159
path_elements: PathElementsFn
160+
flatten_with_paths: OptimizedFlattenFn | None = None
158161

159162

160163
class NodeTraverserRegistry:
@@ -185,6 +188,7 @@ def register_node_traverser(
185188
flatten_fn: FlattenFn,
186189
unflatten_fn: UnflattenFn,
187190
path_elements_fn: PathElementsFn,
191+
flatten_with_paths_fn: OptimizedFlattenFn | None = None,
188192
) -> None:
189193
"""Registers a node traverser for `node_type`.
190194
@@ -202,6 +206,9 @@ def register_node_traverser(
202206
flattened values returned by `flatten_fn`. This should accept an
203207
instance of `node_type`, and return a sequence of `PathElement`s aligned
204208
with the values returned by `flatten_fn`.
209+
flatten_with_paths_fn: A version of `flatten_fn` that returns an iterable
210+
of `(value, path)` pairs, where `value` is a child value and `path` is a
211+
`Path` to the value.
205212
"""
206213
if not isinstance(node_type, type):
207214
raise TypeError(f"`node_type` ({node_type}) must be a type.")
@@ -212,6 +219,7 @@ def register_node_traverser(
212219
flatten=flatten_fn,
213220
unflatten=unflatten_fn,
214221
path_elements=path_elements_fn,
222+
flatten_with_paths=flatten_with_paths_fn,
215223
)
216224

217225
def find_node_traverser(
@@ -282,7 +290,9 @@ def unflatten_defaultdict(values, metadata):
282290
tuple,
283291
flatten_fn=lambda x: (x, None),
284292
unflatten_fn=lambda x, _: tuple(x),
285-
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))))
293+
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))),
294+
flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)),
295+
)
286296

287297
register_node_traverser(
288298
NamedTupleType,
@@ -294,7 +304,9 @@ def unflatten_defaultdict(values, metadata):
294304
list,
295305
flatten_fn=lambda x: (tuple(x), None),
296306
unflatten_fn=lambda x, _: list(x),
297-
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))))
307+
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))),
308+
flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)),
309+
)
298310

299311

300312
def is_prefix(prefix_path: Path, containing_path: Path):
@@ -610,6 +622,42 @@ def flattened_map_children(self, value: Any) -> SubTraversalResult:
610622
raise ValueError("Please handle non-traversable values yourself.")
611623
return self._flattened_map_children(value, node_traverser)
612624

625+
def fast_map_child_values(
626+
self, value: Any, ignore_leaves: bool = False
627+
) -> Iterable[Any]:
628+
"""Maps over children for traversable values, but doesn't unflatten results.
629+
630+
This method only returns result values, so use it in place of
631+
`state.flattened_map_children(value).values`.
632+
633+
Args:
634+
value: Value to map over.
635+
ignore_leaves: If True, then this function will return an empty iterable
636+
if `value` is not traversable. Otherwise, it will raise a ValueError.
637+
638+
Yields:
639+
Sub-traversal results, the same type as returned by your _traverse
640+
function.
641+
642+
Raises:
643+
ValueError: If `value` is not traversable. Please test beforehand by
644+
calling `state.is_traversable()`.
645+
"""
646+
node_traverser = self.traversal.find_node_traverser(type(value))
647+
if node_traverser is None:
648+
if ignore_leaves:
649+
return
650+
else:
651+
raise ValueError("Please handle non-traversable values yourself.")
652+
if node_traverser.flatten_with_paths is not None:
653+
for value, path in node_traverser.flatten_with_paths(value):
654+
yield self.call(value, path)
655+
else:
656+
sub_values, unused_meta = node_traverser.flatten(value)
657+
path_elements = node_traverser.path_elements(value)
658+
for sub_value, path_element in zip(sub_values, path_elements):
659+
yield self.call(sub_value, path_element)
660+
613661
def call(self, value, *additional_path: PathElement):
614662
"""Low-level function to execute a sub-traversal.
615663

0 commit comments

Comments
 (0)