@@ -143,6 +143,8 @@ class NamedTupleType:
143
143
# This has the same type as Path, but different semantic meaning.
144
144
PathElements = Tuple [PathElement , ...]
145
145
PathElementsFn = Callable [[Any ], PathElements ]
146
+ _ValueAndPath = Tuple [Any , PathElement ]
147
+ OptimizedFlattenFn = Callable [[Any ], Iterable [_ValueAndPath ]]
146
148
FlattenFn = Callable [[Any ], Tuple [Tuple [Any , ...], Any ]]
147
149
UnflattenFn = Callable [[Iterable [Any ], Any ], Any ]
148
150
@@ -155,6 +157,7 @@ class NodeTraverser:
155
157
flatten : FlattenFn
156
158
unflatten : UnflattenFn
157
159
path_elements : PathElementsFn
160
+ flatten_with_paths : OptimizedFlattenFn | None = None
158
161
159
162
160
163
class NodeTraverserRegistry :
@@ -185,6 +188,7 @@ def register_node_traverser(
185
188
flatten_fn : FlattenFn ,
186
189
unflatten_fn : UnflattenFn ,
187
190
path_elements_fn : PathElementsFn ,
191
+ flatten_with_paths_fn : OptimizedFlattenFn | None = None ,
188
192
) -> None :
189
193
"""Registers a node traverser for `node_type`.
190
194
@@ -202,6 +206,9 @@ def register_node_traverser(
202
206
flattened values returned by `flatten_fn`. This should accept an
203
207
instance of `node_type`, and return a sequence of `PathElement`s aligned
204
208
with the values returned by `flatten_fn`.
209
+ flatten_with_paths_fn: A version of `flatten_fn` that returns an iterable
210
+ of `(value, path)` pairs, where `value` is a child value and `path` is a
211
+ `Path` to the value.
205
212
"""
206
213
if not isinstance (node_type , type ):
207
214
raise TypeError (f"`node_type` ({ node_type } ) must be a type." )
@@ -212,6 +219,7 @@ def register_node_traverser(
212
219
flatten = flatten_fn ,
213
220
unflatten = unflatten_fn ,
214
221
path_elements = path_elements_fn ,
222
+ flatten_with_paths = flatten_with_paths_fn ,
215
223
)
216
224
217
225
def find_node_traverser (
@@ -282,7 +290,9 @@ def unflatten_defaultdict(values, metadata):
282
290
tuple ,
283
291
flatten_fn = lambda x : (x , None ),
284
292
unflatten_fn = lambda x , _ : tuple (x ),
285
- path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))))
293
+ path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))),
294
+ flatten_with_paths_fn = lambda xs : ((x , Index (i )) for i , x in enumerate (xs )),
295
+ )
286
296
287
297
register_node_traverser (
288
298
NamedTupleType ,
@@ -294,7 +304,9 @@ def unflatten_defaultdict(values, metadata):
294
304
list ,
295
305
flatten_fn = lambda x : (tuple (x ), None ),
296
306
unflatten_fn = lambda x , _ : list (x ),
297
- path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))))
307
+ path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))),
308
+ flatten_with_paths_fn = lambda xs : ((x , Index (i )) for i , x in enumerate (xs )),
309
+ )
298
310
299
311
300
312
def is_prefix (prefix_path : Path , containing_path : Path ):
@@ -610,6 +622,42 @@ def flattened_map_children(self, value: Any) -> SubTraversalResult:
610
622
raise ValueError ("Please handle non-traversable values yourself." )
611
623
return self ._flattened_map_children (value , node_traverser )
612
624
625
+ def fast_map_child_values (
626
+ self , value : Any , ignore_leaves : bool = False
627
+ ) -> Iterable [Any ]:
628
+ """Maps over children for traversable values, but doesn't unflatten results.
629
+
630
+ This method only returns result values, so use it in place of
631
+ `state.flattened_map_children(value).values`.
632
+
633
+ Args:
634
+ value: Value to map over.
635
+ ignore_leaves: If True, then this function will return an empty iterable
636
+ if `value` is not traversable. Otherwise, it will raise a ValueError.
637
+
638
+ Yields:
639
+ Sub-traversal results, the same type as returned by your _traverse
640
+ function.
641
+
642
+ Raises:
643
+ ValueError: If `value` is not traversable. Please test beforehand by
644
+ calling `state.is_traversable()`.
645
+ """
646
+ node_traverser = self .traversal .find_node_traverser (type (value ))
647
+ if node_traverser is None :
648
+ if ignore_leaves :
649
+ return
650
+ else :
651
+ raise ValueError ("Please handle non-traversable values yourself." )
652
+ if node_traverser .flatten_with_paths is not None :
653
+ for value , path in node_traverser .flatten_with_paths (value ):
654
+ yield self .call (value , path )
655
+ else :
656
+ sub_values , unused_meta = node_traverser .flatten (value )
657
+ path_elements = node_traverser .path_elements (value )
658
+ for sub_value , path_element in zip (sub_values , path_elements ):
659
+ yield self .call (sub_value , path_element )
660
+
613
661
def call (self , value , * additional_path : PathElement ):
614
662
"""Low-level function to execute a sub-traversal.
615
663
0 commit comments