From e3f4c98422803fcb0e85450b15445b8cdb10bac8 Mon Sep 17 00:00:00 2001 From: Fiddle-Config Team Date: Mon, 11 Mar 2024 20:08:08 -0700 Subject: [PATCH] [draft] Optimized API for Daglish read-only traversals (traversals that do not re-construct objects). PiperOrigin-RevId: 614885965 --- .../auto_config/complex_to_variables.py | 3 +- fiddle/_src/codegen/legacy_codegen.py | 10 +- fiddle/_src/daglish.py | 75 +++++++++++-- fiddle/_src/daglish_test.py | 103 +++++++++++++++++- fiddle/_src/debug/grep.py | 3 +- fiddle/_src/experimental/visualize.py | 6 +- fiddle/_src/graphviz.py | 9 +- fiddle/_src/materialize.py | 4 +- fiddle/_src/partial.py | 4 +- fiddle/_src/printing.py | 12 +- fiddle/_src/selectors.py | 2 +- fiddle/_src/validation/baseline_style.py | 4 +- 12 files changed, 197 insertions(+), 38 deletions(-) diff --git a/fiddle/_src/codegen/auto_config/complex_to_variables.py b/fiddle/_src/codegen/auto_config/complex_to_variables.py index e12dc300..f29c1697 100644 --- a/fiddle/_src/codegen/auto_config/complex_to_variables.py +++ b/fiddle/_src/codegen/auto_config/complex_to_variables.py @@ -50,8 +50,7 @@ def traverse(value, state: daglish.State) -> int: elif not state.is_traversable(value): return 1 else: - sub_values = state.flattened_map_children(value) - return 1 + sum(sub_values.values) + return 1 + sum(state.yield_map_child_values(value)) return lambda x: daglish.MemoizedTraversal.run(traverse, x) > level diff --git a/fiddle/_src/codegen/legacy_codegen.py b/fiddle/_src/codegen/legacy_codegen.py index 7d326eff..11e651c7 100644 --- a/fiddle/_src/codegen/legacy_codegen.py +++ b/fiddle/_src/codegen/legacy_codegen.py @@ -72,8 +72,8 @@ def traverse(value, state: daglish.State): if isinstance(value, config_lib.Buildable): to_count[id(value)] += 1 children_by_id[id(value)] = value - if state.is_traversable(value): - state.flattened_map_children(value) + for _ in state.yield_map_child_values(value, ignore_leaves=True): + pass # Run lazy iterator. 
daglish.BasicTraversal.run(traverse, buildable) return [ @@ -194,10 +194,10 @@ def _configure_shared_objects( variable_name_prefix: Prefix for any variables introduced. """ - def traverse(child, state): + def traverse(child, state: daglish.State): """Generates code for a shared instance.""" - if state.is_traversable(child): - state.flattened_map_children(child) + for _ in state.yield_map_child_values(child, ignore_leaves=True): + pass # Run lazy iterator. if isinstance(child, config_lib.Buildable): # Name this better.. name = shared_manager.namespace.get_new_name( diff --git a/fiddle/_src/daglish.py b/fiddle/_src/daglish.py index 5bd30a51..06fba7e1 100644 --- a/fiddle/_src/daglish.py +++ b/fiddle/_src/daglish.py @@ -143,6 +143,8 @@ class NamedTupleType: # This has the same type as Path, but different semantic meaning. PathElements = Tuple[PathElement, ...] PathElementsFn = Callable[[Any], PathElements] +_ValueAndPath = Tuple[Any, PathElement] +OptimizedFlattenFn = Callable[[Any], Iterable[_ValueAndPath]] FlattenFn = Callable[[Any], Tuple[Tuple[Any, ...], Any]] UnflattenFn = Callable[[Iterable[Any], Any], Any] @@ -155,6 +157,7 @@ class NodeTraverser: flatten: FlattenFn unflatten: UnflattenFn path_elements: PathElementsFn + flatten_with_paths: OptimizedFlattenFn | None = None class NodeTraverserRegistry: @@ -185,6 +188,7 @@ def register_node_traverser( flatten_fn: FlattenFn, unflatten_fn: UnflattenFn, path_elements_fn: PathElementsFn, + flatten_with_paths_fn: OptimizedFlattenFn | None = None, ) -> None: """Registers a node traverser for `node_type`. @@ -202,6 +206,9 @@ def register_node_traverser( flattened values returned by `flatten_fn`. This should accept an instance of `node_type`, and return a sequence of `PathElement`s aligned with the values returned by `flatten_fn`. + flatten_with_paths_fn: A version of `flatten_fn` that returns an iterable + of `(value, path_element)` pairs, where `value` is a child value and + `path_element` is the single `PathElement` from the parent to that value. 
""" if not isinstance(node_type, type): raise TypeError(f"`node_type` ({node_type}) must be a type.") @@ -212,6 +219,7 @@ def register_node_traverser( flatten=flatten_fn, unflatten=unflatten_fn, path_elements=path_elements_fn, + flatten_with_paths=flatten_with_paths_fn, ) def find_node_traverser( @@ -282,7 +290,9 @@ def unflatten_defaultdict(values, metadata): tuple, flatten_fn=lambda x: (x, None), unflatten_fn=lambda x, _: tuple(x), - path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x)))) + path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))), + flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)), +) register_node_traverser( NamedTupleType, @@ -294,7 +304,9 @@ def unflatten_defaultdict(values, metadata): list, flatten_fn=lambda x: (tuple(x), None), unflatten_fn=lambda x, _: list(x), - path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x)))) + path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))), + flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)), +) def is_prefix(prefix_path: Path, containing_path: Path): @@ -607,9 +619,57 @@ def flattened_map_children(self, value: Any) -> SubTraversalResult: """ node_traverser = self.traversal.find_node_traverser(type(value)) if node_traverser is None: - raise ValueError("Please handle non-traversable values yourself.") + raise ValueError( + f"No node traverser found for {value}, please register traverser" + " or pre-process the non-traversable values first." + ) return self._flattened_map_children(value, node_traverser) + def yield_map_child_values( + self, value: Any, ignore_leaves: bool = False + ) -> Iterable[Any]: + """Maps over children for traversable values, but doesn't unflatten results. + + This method only returns result values, so use it in place of + `state.flattened_map_children(value).values` when the result is consumed + once. 
If you are calling this for an "effectful" computation (including ones + that just gather results in `nonlocal` variables) then consider the + following pattern, + + for _ in state.yield_map_child_values(node, ignore_leaves=True): + pass # Run lazy iterator. + + Args: + value: Value to map over. + ignore_leaves: If True, then this function will return an empty iterable + if `value` is not traversable. Otherwise, it will raise a ValueError. + + Yields: + Sub-traversal results, the same type as returned by your _traverse + function. + + Raises: + ValueError: If `value` is not traversable and `ignore_leaves` is `False`. + Please test beforehand by calling `state.is_traversable()`. + """ + node_traverser = self.traversal.find_node_traverser(type(value)) + if node_traverser is None: + if ignore_leaves: + return + else: + raise ValueError( + f"No node traverser found for {value}, please register traverser" + " or pre-process the non-traversable values first." + ) + if node_traverser.flatten_with_paths is not None: + for value, path in node_traverser.flatten_with_paths(value): + yield self.call(value, path) + else: + sub_values, unused_meta = node_traverser.flatten(value) + path_elements = node_traverser.path_elements(value) + for sub_value, path_element in zip(sub_values, path_elements): + yield self.call(sub_value, path_element) + def call(self, value, *additional_path: PathElement): """Low-level function to execute a sub-traversal. @@ -755,8 +815,8 @@ def collect_paths_by_id( def traverse(value, state: State): if not memoizable_only or is_memoizable(value): paths_by_id.setdefault(id(value), []).append(state.current_path) - if state.is_traversable(value): - state.flattened_map_children(value) + for _ in state.yield_map_child_values(value, ignore_leaves=True): + pass # Run lazy iterator. 
traversal = BasicTraversal(traverse, structure, registry=registry) traverse(structure, traversal.initial_state()) @@ -794,9 +854,8 @@ def iterate( def _traverse(node, state: State): yield node, state.current_path - if state.is_traversable(node): - for sub_result in state.flattened_map_children(node).values: - yield from sub_result + for sub_result in state.yield_map_child_values(node, ignore_leaves=True): + yield from sub_result if memoized: traversal = MemoizedTraversal( diff --git a/fiddle/_src/daglish_test.py b/fiddle/_src/daglish_test.py index e0f43f81..d1a172d7 100644 --- a/fiddle/_src/daglish_test.py +++ b/fiddle/_src/daglish_test.py @@ -20,7 +20,7 @@ import enum import json import random -from typing import Any, List, NamedTuple, Optional, cast +from typing import Any, List, NamedTuple, Optional, Tuple, cast from absl.testing import absltest from absl.testing import parameterized @@ -76,6 +76,16 @@ def __call__(self, value, state: Optional[daglish.State] = None): return state.map_children(value) +@dataclasses.dataclass +class NonRecursingLoggingFunction: + values_and_paths: List[Tuple[Any, Any]] = dataclasses.field( + default_factory=list + ) + + def __call__(self, value, state: daglish.State): + self.values_and_paths.append((value, state.current_path)) + + def switch_buildables_to_args(value, state: Optional[daglish.State] = None): """Replaces buildables with their arguments dictionary. 
@@ -97,6 +107,28 @@ def switch_buildables_to_args(value, state: Optional[daglish.State] = None): return value +@dataclasses.dataclass +class MyRange: + start: int + end: int + + +def myrange_flatten_with_paths_fn(myrange: MyRange): + return ( + ({"value": i}, daglish.Index(i)) + for i in range(myrange.start, myrange.end) + ) + + +daglish.register_node_traverser( + MyRange, + flatten_fn=NotImplemented, # pytype: disable=wrong-arg-types + unflatten_fn=NotImplemented, # pytype: disable=wrong-arg-types + path_elements_fn=NotImplemented, # pytype: disable=wrong-arg-types + flatten_with_paths_fn=myrange_flatten_with_paths_fn, +) + + class PathTest(parameterized.TestCase): @parameterized.named_parameters( @@ -437,6 +469,73 @@ def test_argument_history(self): history.ChangeKind.NEW_VALUE) +_eager_map_fns = [ + lambda state, obj: state.map_children(obj), + lambda state, obj: list(state.yield_map_child_values(obj)), + lambda state, obj: state.flattened_map_children(obj), +] + + +class StateApiTest(parameterized.TestCase): + + @parameterized.parameters(_eager_map_fns) + def test_map_dict(self, map_fn): + obj = {"a": 1, "b": 2} + log_calls = NonRecursingLoggingFunction() + traversal = daglish.BasicTraversal(log_calls, obj) + state = traversal.initial_state() + map_fn(state, obj) + self.assertEqual( + log_calls.values_and_paths, + [(1, (daglish.Key(key="a"),)), (2, (daglish.Key(key="b"),))], + ) + + @parameterized.parameters(_eager_map_fns) + def test_map_tuple(self, map_fn): + obj = ((), (1, 2), 3) + log_calls = NonRecursingLoggingFunction() + traversal = daglish.BasicTraversal(log_calls, obj) + state = traversal.initial_state() + map_fn(state, obj) + self.assertEqual( + log_calls.values_and_paths, + [ + ((), (daglish.Index(index=0),)), + ((1, 2), (daglish.Index(index=1),)), + (3, (daglish.Index(index=2),)), + ], + ) + + @parameterized.parameters(_eager_map_fns) + def test_map_memoized(self, map_fn): + shared = {"foo": 123} + obj = [shared, shared, shared] + log_calls = 
NonRecursingLoggingFunction() + traversal = daglish.MemoizedTraversal(log_calls, obj) + state = traversal.initial_state() + map_fn(state, obj) + self.assertEqual( + log_calls.values_and_paths, [({"foo": 123}, (daglish.Index(index=0),))] + ) + + def test_fast_map_calls_flatten_with_paths(self): + obj = MyRange(3, 7) + log_calls = NonRecursingLoggingFunction() + traversal = daglish.MemoizedTraversal(log_calls, obj) + state = traversal.initial_state() + for _ in state.yield_map_child_values(obj): + pass # Run lazy iterator. + self.assertEqual( + log_calls.values_and_paths, + [ + ({"value": 3}, (daglish.Index(index=3),)), + ({"value": 4}, (daglish.Index(index=4),)), + ({"value": 5}, (daglish.Index(index=5),)), + ({"value": 6}, (daglish.Index(index=6),)), + ], + ) + + class ArgsSwitchingFuzzTest(parameterized.TestCase): def test_fuzz(self): @@ -595,7 +694,7 @@ def traverse(value, state: daglish.State): daglish.path_str(path) for path in state.get_all_paths(allow_caching=True) ), - "sub_values": state.flattened_map_children(value).values, + "sub_values": list(state.yield_map_child_values(value)), } else: return "leaf value" diff --git a/fiddle/_src/debug/grep.py b/fiddle/_src/debug/grep.py index b5b2bc95..38a03216 100644 --- a/fiddle/_src/debug/grep.py +++ b/fiddle/_src/debug/grep.py @@ -63,7 +63,8 @@ def grep( def traverse(value, state: daglish.State): path_str = daglish.path_str(state.current_path) if state.is_traversable(value): - state.flattened_map_children(value) + for _ in state.yield_map_child_values(value): + pass # Run lazy iterator. 
if isinstance(value, config_lib.Buildable): fn_or_cls = config_lib.get_callable(value) value_str = f"<{type(value).__name__}({fn_or_cls.__name__})>" diff --git a/fiddle/_src/experimental/visualize.py b/fiddle/_src/experimental/visualize.py index afb989c8..08ae7dbc 100644 --- a/fiddle/_src/experimental/visualize.py +++ b/fiddle/_src/experimental/visualize.py @@ -120,7 +120,7 @@ def _helper(value, substate: daglish.State) -> bool: for sub_path in substate.get_all_paths(): if not _goes_through_node(sub_path): return False - return all(substate.flattened_map_children(value).values) + return all(substate.yield_map_child_values(value)) # Creates a sub-traversal using a different function. We eventually might # make this part of the daglish API. @@ -199,8 +199,8 @@ def traverse(node, state: daglish.State) -> None: all_paths = state.get_all_paths(allow_caching=True) node_to_depth[id(node)] = min(_path_len(path) for path in all_paths) id_to_node[id(node)] = node - if state.is_traversable(node): - state.flattened_map_children(node) + for _ in state.yield_map_child_values(node, ignore_leaves=True): + pass # Run lazy iterator. 
daglish.MemoizedTraversal.run(traverse, config) diff --git a/fiddle/_src/graphviz.py b/fiddle/_src/graphviz.py index 8cb5cd8f..9ba81a2a 100644 --- a/fiddle/_src/graphviz.py +++ b/fiddle/_src/graphviz.py @@ -833,10 +833,10 @@ def visit(value, state: daglish.State): """Returns true if any child has changed.""" parents_changed = id(value) in changed_parent_ids if state.is_traversable(value): - subtraversal = state.flattened_map_children(value) - any_changed = any(subtraversal.values) + child_results = list(state.yield_map_child_values(value)) + any_changed = any(child_results) if isinstance(value, dict) and id(value) in old_value_ids: - _trim_dict(value, subtraversal.values) + _trim_dict(value, child_results) return any_changed or parents_changed elif isinstance(value, _ChangedValue): state.call(value.old_value, daglish.Attr('old_value')) @@ -866,7 +866,8 @@ def _find_mutable_values_with_changed_parents(structure_with_changed_values): def visit(value, state: daglish.State): if state.is_traversable(value): - state.flattened_map_children(value) + for _ in state.yield_map_child_values(value): + pass # Run lazy iterator. elif isinstance(value, _ChangedValue): assert value.old_value is not value.new_value if daglish.is_memoizable(value.old_value): diff --git a/fiddle/_src/materialize.py b/fiddle/_src/materialize.py index 11262bd2..02adb908 100644 --- a/fiddle/_src/materialize.py +++ b/fiddle/_src/materialize.py @@ -55,7 +55,7 @@ def traverse(node, state: daglish.State): for arg in node.__signature_info__.parameters.values(): if arg.default is not arg.empty and arg.name not in node.__arguments__: setattr(node, arg.name, arg.default) - if state.is_traversable(node): - state.flattened_map_children(node) + for _ in state.yield_map_child_values(node, ignore_leaves=True): + pass # Run lazy iterator. 
daglish.MemoizedTraversal.run(traverse, value) diff --git a/fiddle/_src/partial.py b/fiddle/_src/partial.py index 6d386934..b8d4b569 100644 --- a/fiddle/_src/partial.py +++ b/fiddle/_src/partial.py @@ -45,11 +45,11 @@ class _BuiltArgFactory: def _contains_arg_factory(value: Any) -> bool: """Returns true if ``value`` contains any ``_BuiltArgFactory`` instances.""" - def visit(node, state): + def visit(node, state: daglish.State): if isinstance(node, _BuiltArgFactory): return True elif state.is_traversable(node): - return any(state.flattened_map_children(node).values) + return any(state.yield_map_child_values(node)) else: return False diff --git a/fiddle/_src/printing.py b/fiddle/_src/printing.py index 859923d7..d4bb5b9c 100644 --- a/fiddle/_src/printing.py +++ b/fiddle/_src/printing.py @@ -45,9 +45,9 @@ def __repr__(self) -> str: def _has_nested_builder(value: Any, state=None) -> bool: state = state or daglish.MemoizedTraversal.begin(_has_nested_builder, value) - return (isinstance(value, config.Buildable) or - (state.is_traversable(value) and - any(state.flattened_map_children(value).values))) + return isinstance(value, config.Buildable) or ( + state.is_traversable(value) and any(state.yield_map_child_values(value)) + ) def _path_str(path: daglish.Path) -> str: @@ -205,7 +205,7 @@ def generate(value, state=None) -> Iterator[_LeafSetting]: else: # value must be a Buildable or a traversable containing a Buidable. assert state.is_traversable(value) - for sub_result in state.flattened_map_children(value).values: + for sub_result in state.yield_map_child_values(value): yield from sub_result # Used in format_line below. 
The use of getattr and the dummy type default @@ -288,7 +288,7 @@ def traverse(value, state=None) -> Iterator[str]: yield f'{_path_str(path)} = <[unset]>' elif state.is_traversable(value): - for sub_result in state.flattened_map_children(value).values: + for sub_result in state.yield_map_child_values(value): yield from sub_result else: yield f'{_path_str(state.current_path)} = {value}' @@ -359,7 +359,7 @@ def dict_generate(value, state=None) -> Iterator[_LeafSetting]: else: # value must be a Buildable or a traversable containing a Buildable. assert state.is_traversable(value) - for sub_result in state.flattened_map_children(value).values: + for sub_result in state.yield_map_child_values(value): yield from sub_result args_dict = {} diff --git a/fiddle/_src/selectors.py b/fiddle/_src/selectors.py index c56bd69a..94a9e323 100644 --- a/fiddle/_src/selectors.py +++ b/fiddle/_src/selectors.py @@ -81,7 +81,7 @@ def _memoized_walk_leaves_first(value, state=None): state = state or daglish.MemoizedTraversal.begin(_memoized_walk_leaves_first, value) if state.is_traversable(value): - for sub_result in state.flattened_map_children(value).values: + for sub_result in state.yield_map_child_values(value): yield from sub_result yield value diff --git a/fiddle/_src/validation/baseline_style.py b/fiddle/_src/validation/baseline_style.py index 077e0a73..70558585 100644 --- a/fiddle/_src/validation/baseline_style.py +++ b/fiddle/_src/validation/baseline_style.py @@ -142,8 +142,8 @@ def traverse(value: Any, state: daglish.State) -> Any: location: history.Location = history_entry.location sample_paths_by_file.setdefault(location.filename, path_str) - if state.is_traversable(value): - state.flattened_map_children(value) + for _ in state.yield_map_child_values(value, ignore_leaves=True): + pass # Run lazy iterator. daglish.MemoizedTraversal.run(traverse, config)