Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions fiddle/_src/codegen/auto_config/complex_to_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def traverse(value, state: daglish.State) -> int:
elif not state.is_traversable(value):
return 1
else:
sub_values = state.flattened_map_children(value)
return 1 + sum(sub_values.values)
return 1 + sum(state.yield_map_child_values(value))

return lambda x: daglish.MemoizedTraversal.run(traverse, x) > level

Expand Down
10 changes: 5 additions & 5 deletions fiddle/_src/codegen/legacy_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def traverse(value, state: daglish.State):
if isinstance(value, config_lib.Buildable):
to_count[id(value)] += 1
children_by_id[id(value)] = value
if state.is_traversable(value):
state.flattened_map_children(value)
for _ in state.yield_map_child_values(value, ignore_leaves=True):
pass # Run lazy iterator.

daglish.BasicTraversal.run(traverse, buildable)
return [
Expand Down Expand Up @@ -194,10 +194,10 @@ def _configure_shared_objects(
variable_name_prefix: Prefix for any variables introduced.
"""

def traverse(child, state):
def traverse(child, state: daglish.State):
"""Generates code for a shared instance."""
if state.is_traversable(child):
state.flattened_map_children(child)
for _ in state.yield_map_child_values(child, ignore_leaves=True):
pass # Run lazy iterator.
if isinstance(child, config_lib.Buildable):
# Name this better..
name = shared_manager.namespace.get_new_name(
Expand Down
75 changes: 67 additions & 8 deletions fiddle/_src/daglish.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class NamedTupleType:
# This has the same type as Path, but different semantic meaning.
PathElements = Tuple[PathElement, ...]
PathElementsFn = Callable[[Any], PathElements]
# A child value paired with the single `PathElement` that locates it inside its
# parent node (despite the alias name, this is not a full `Path`).
_ValueAndPath = Tuple[Any, PathElement]
# Signature of the optional `flatten_with_paths` traverser hook: takes a node
# and returns an iterable of `_ValueAndPath` pairs.
OptimizedFlattenFn = Callable[[Any], Iterable[_ValueAndPath]]
FlattenFn = Callable[[Any], Tuple[Tuple[Any, ...], Any]]
UnflattenFn = Callable[[Iterable[Any], Any], Any]

Expand All @@ -155,6 +157,7 @@ class NodeTraverser:
flatten: FlattenFn
unflatten: UnflattenFn
path_elements: PathElementsFn
flatten_with_paths: OptimizedFlattenFn | None = None


class NodeTraverserRegistry:
Expand Down Expand Up @@ -185,6 +188,7 @@ def register_node_traverser(
flatten_fn: FlattenFn,
unflatten_fn: UnflattenFn,
path_elements_fn: PathElementsFn,
flatten_with_paths_fn: OptimizedFlattenFn | None = None,
) -> None:
"""Registers a node traverser for `node_type`.

Expand All @@ -202,6 +206,9 @@ def register_node_traverser(
flattened values returned by `flatten_fn`. This should accept an
instance of `node_type`, and return a sequence of `PathElement`s aligned
with the values returned by `flatten_fn`.
flatten_with_paths_fn: Optional. A version of `flatten_fn` that returns an
iterable of `(value, path_element)` pairs, where `value` is a child value
and `path_element` is the `PathElement` locating it within the node.
"""
if not isinstance(node_type, type):
raise TypeError(f"`node_type` ({node_type}) must be a type.")
Expand All @@ -212,6 +219,7 @@ def register_node_traverser(
flatten=flatten_fn,
unflatten=unflatten_fn,
path_elements=path_elements_fn,
flatten_with_paths=flatten_with_paths_fn,
)

def find_node_traverser(
Expand Down Expand Up @@ -282,7 +290,9 @@ def unflatten_defaultdict(values, metadata):
tuple,
flatten_fn=lambda x: (x, None),
unflatten_fn=lambda x, _: tuple(x),
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))))
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))),
flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)),
)

register_node_traverser(
NamedTupleType,
Expand All @@ -294,7 +304,9 @@ def unflatten_defaultdict(values, metadata):
list,
flatten_fn=lambda x: (tuple(x), None),
unflatten_fn=lambda x, _: list(x),
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))))
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))),
flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)),
)


def is_prefix(prefix_path: Path, containing_path: Path):
Expand Down Expand Up @@ -607,9 +619,57 @@ def flattened_map_children(self, value: Any) -> SubTraversalResult:
"""
node_traverser = self.traversal.find_node_traverser(type(value))
if node_traverser is None:
raise ValueError("Please handle non-traversable values yourself.")
raise ValueError(
f"No node traverser found for {value}, please register traverser"
" or pre-process the non-traversable values first."
)
return self._flattened_map_children(value, node_traverser)

def yield_map_child_values(
    self, value: Any, ignore_leaves: bool = False
) -> Iterable[Any]:
  """Lazily maps the traversal over the children of `value`.

  Unlike `flattened_map_children`, the results are not collected or
  unflattened; they are yielded one at a time. Use this in place of
  `state.flattened_map_children(value).values` when the results are consumed
  only once.

  This method is a generator: nothing happens until the returned iterator is
  consumed. No child is visited, and the `ValueError` described below is not
  raised, at the time of the call itself. If you are calling this for an
  "effectful" computation (including ones that just gather results in
  `nonlocal` variables), drain the iterator explicitly:

      for _ in state.yield_map_child_values(node, ignore_leaves=True):
        pass  # Run lazy iterator.

  Args:
    value: Value whose children should be mapped over.
    ignore_leaves: If True, a non-traversable `value` produces an empty
      iterator. If False, iterating over the result for a non-traversable
      `value` raises a ValueError.

  Yields:
    Sub-traversal results, one per child, of the same type as returned by
    your traversal function.

  Raises:
    ValueError: On first iteration, if `value` is not traversable and
      `ignore_leaves` is False. Test beforehand with
      `state.is_traversable()`.
  """
  node_traverser = self.traversal.find_node_traverser(type(value))
  if node_traverser is None:
    if ignore_leaves:
      return
    raise ValueError(
        f"No node traverser found for {value}, please register traverser"
        " or pre-process the non-traversable values first."
    )

  if node_traverser.flatten_with_paths is not None:
    # Fast path: the traverser hands out children together with their path
    # elements, so `flatten` and `path_elements` are not called. The loop
    # variables are named so they do not rebind the `value` parameter.
    for child, path_element in node_traverser.flatten_with_paths(value):
      yield self.call(child, path_element)
  else:
    # Generic path: pair the flattened children with their path elements. The
    # flatten metadata is only needed for unflattening, which never happens
    # here.
    sub_values, _ = node_traverser.flatten(value)
    path_elements = node_traverser.path_elements(value)
    for sub_value, path_element in zip(sub_values, path_elements):
      yield self.call(sub_value, path_element)

def call(self, value, *additional_path: PathElement):
"""Low-level function to execute a sub-traversal.

Expand Down Expand Up @@ -755,8 +815,8 @@ def collect_paths_by_id(
def traverse(value, state: State):
if not memoizable_only or is_memoizable(value):
paths_by_id.setdefault(id(value), []).append(state.current_path)
if state.is_traversable(value):
state.flattened_map_children(value)
for _ in state.yield_map_child_values(value, ignore_leaves=True):
pass # Run lazy iterator.

traversal = BasicTraversal(traverse, structure, registry=registry)
traverse(structure, traversal.initial_state())
Expand Down Expand Up @@ -794,9 +854,8 @@ def iterate(

def _traverse(node, state: State):
yield node, state.current_path
if state.is_traversable(node):
for sub_result in state.flattened_map_children(node).values:
yield from sub_result
for sub_result in state.yield_map_child_values(node, ignore_leaves=True):
yield from sub_result

if memoized:
traversal = MemoizedTraversal(
Expand Down
103 changes: 101 additions & 2 deletions fiddle/_src/daglish_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import enum
import json
import random
from typing import Any, List, NamedTuple, Optional, cast
from typing import Any, List, NamedTuple, Optional, Tuple, cast

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -76,6 +76,16 @@ def __call__(self, value, state: Optional[daglish.State] = None):
return state.map_children(value)


@dataclasses.dataclass
class NonRecursingLoggingFunction:
  """Traversal callback that records each visit and never recurses.

  Every call appends a `(value, state.current_path)` pair to
  `values_and_paths`. The callback does not map over the children of `value`
  and returns None.
  """

  # Recorded (value, current path) pairs, in call order.
  values_and_paths: List[Tuple[Any, Any]] = dataclasses.field(
      default_factory=list
  )

  def __call__(self, value, state: daglish.State):
    visit = (value, state.current_path)
    self.values_and_paths.append(visit)


def switch_buildables_to_args(value, state: Optional[daglish.State] = None):
"""Replaces buildables with their arguments dictionary.

Expand All @@ -97,6 +107,28 @@ def switch_buildables_to_args(value, state: Optional[daglish.State] = None):
return value


@dataclasses.dataclass
class MyRange:
  """Test-only node type holding two integer bounds, `start` and `end`."""

  start: int
  end: int


def myrange_flatten_with_paths_fn(myrange: MyRange):
  """Returns a lazy iterable of `({"value": i}, Index(i))` pairs.

  One pair is produced for each integer `i` in
  `range(myrange.start, myrange.end)`.
  """
  indices = range(myrange.start, myrange.end)
  return (({"value": i}, daglish.Index(i)) for i in indices)


# Only `flatten_with_paths_fn` is a real function here; the other three hooks
# are `NotImplemented` placeholders, which are not callable. A traversal over a
# `MyRange` can therefore only succeed through the `flatten_with_paths` hook.
daglish.register_node_traverser(
MyRange,
flatten_fn=NotImplemented, # pytype: disable=wrong-arg-types
unflatten_fn=NotImplemented, # pytype: disable=wrong-arg-types
path_elements_fn=NotImplemented, # pytype: disable=wrong-arg-types
flatten_with_paths_fn=myrange_flatten_with_paths_fn,
)


class PathTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -437,6 +469,73 @@ def test_argument_history(self):
history.ChangeKind.NEW_VALUE)


# Ways of mapping a traversal over the children of `obj` that are expected to
# have visited every child by the time the call returns. The lazy
# `yield_map_child_values` iterator is wrapped in `list()` to force it.
_eager_map_fns = [
lambda state, obj: state.map_children(obj),
lambda state, obj: list(state.yield_map_child_values(obj)),
lambda state, obj: state.flattened_map_children(obj),
]


class StateApiTest(parameterized.TestCase):
  """Checks which children the `State` mapping helpers visit, and at what paths."""

  @parameterized.parameters(_eager_map_fns)
  def test_map_dict(self, map_fn):
    """Each dict entry is passed to the traversal function under its `Key`."""
    obj = {"a": 1, "b": 2}
    recorder = NonRecursingLoggingFunction()
    state = daglish.BasicTraversal(recorder, obj).initial_state()
    map_fn(state, obj)
    expected = [
        (1, (daglish.Key(key="a"),)),
        (2, (daglish.Key(key="b"),)),
    ]
    self.assertEqual(recorder.values_and_paths, expected)

  @parameterized.parameters(_eager_map_fns)
  def test_map_tuple(self, map_fn):
    """Each tuple item is passed to the traversal function under its `Index`."""
    obj = ((), (1, 2), 3)
    recorder = NonRecursingLoggingFunction()
    state = daglish.BasicTraversal(recorder, obj).initial_state()
    map_fn(state, obj)
    expected = [
        (child, (daglish.Index(index=position),))
        for position, child in enumerate(obj)
    ]
    self.assertEqual(recorder.values_and_paths, expected)

  @parameterized.parameters(_eager_map_fns)
  def test_map_memoized(self, map_fn):
    """A child shared three times is only logged once, at its first index."""
    shared = {"foo": 123}
    obj = [shared, shared, shared]
    recorder = NonRecursingLoggingFunction()
    state = daglish.MemoizedTraversal(recorder, obj).initial_state()
    map_fn(state, obj)
    expected = [({"foo": 123}, (daglish.Index(index=0),))]
    self.assertEqual(recorder.values_and_paths, expected)

  def test_fast_map_calls_flatten_with_paths(self):
    """Children of a `MyRange` come from its `flatten_with_paths` hook."""
    obj = MyRange(3, 7)
    recorder = NonRecursingLoggingFunction()
    state = daglish.MemoizedTraversal(recorder, obj).initial_state()
    for _ in state.yield_map_child_values(obj):
      pass  # Run lazy iterator.
    expected = [
        ({"value": i}, (daglish.Index(index=i),)) for i in range(3, 7)
    ]
    self.assertEqual(recorder.values_and_paths, expected)


class ArgsSwitchingFuzzTest(parameterized.TestCase):

def test_fuzz(self):
Expand Down Expand Up @@ -595,7 +694,7 @@ def traverse(value, state: daglish.State):
daglish.path_str(path)
for path in state.get_all_paths(allow_caching=True)
),
"sub_values": state.flattened_map_children(value).values,
"sub_values": list(state.yield_map_child_values(value)),
}
else:
return "leaf value"
Expand Down
3 changes: 2 additions & 1 deletion fiddle/_src/debug/grep.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def grep(
def traverse(value, state: daglish.State):
path_str = daglish.path_str(state.current_path)
if state.is_traversable(value):
state.flattened_map_children(value)
for _ in state.yield_map_child_values(value):
pass # Run lazy iterator.
if isinstance(value, config_lib.Buildable):
fn_or_cls = config_lib.get_callable(value)
value_str = f"<{type(value).__name__}({fn_or_cls.__name__})>"
Expand Down
6 changes: 3 additions & 3 deletions fiddle/_src/experimental/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _helper(value, substate: daglish.State) -> bool:
for sub_path in substate.get_all_paths():
if not _goes_through_node(sub_path):
return False
return all(substate.flattened_map_children(value).values)
return all(substate.yield_map_child_values(value))

# Creates a sub-traversal using a different function. We eventually might
# make this part of the daglish API.
Expand Down Expand Up @@ -199,8 +199,8 @@ def traverse(node, state: daglish.State) -> None:
all_paths = state.get_all_paths(allow_caching=True)
node_to_depth[id(node)] = min(_path_len(path) for path in all_paths)
id_to_node[id(node)] = node
if state.is_traversable(node):
state.flattened_map_children(node)
for _ in state.yield_map_child_values(node, ignore_leaves=True):
pass # Run lazy iterator.

daglish.MemoizedTraversal.run(traverse, config)

Expand Down
9 changes: 5 additions & 4 deletions fiddle/_src/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,10 +833,10 @@ def visit(value, state: daglish.State):
"""Returns true if any child has changed."""
parents_changed = id(value) in changed_parent_ids
if state.is_traversable(value):
subtraversal = state.flattened_map_children(value)
any_changed = any(subtraversal.values)
child_results = list(state.yield_map_child_values(value))
any_changed = any(child_results)
if isinstance(value, dict) and id(value) in old_value_ids:
_trim_dict(value, subtraversal.values)
_trim_dict(value, child_results)
return any_changed or parents_changed
elif isinstance(value, _ChangedValue):
state.call(value.old_value, daglish.Attr('old_value'))
Expand Down Expand Up @@ -866,7 +866,8 @@ def _find_mutable_values_with_changed_parents(structure_with_changed_values):

def visit(value, state: daglish.State):
if state.is_traversable(value):
state.flattened_map_children(value)
for _ in state.yield_map_child_values(value):
pass # Run lazy iterator.
elif isinstance(value, _ChangedValue):
assert value.old_value is not value.new_value
if daglish.is_memoizable(value.old_value):
Expand Down
4 changes: 2 additions & 2 deletions fiddle/_src/materialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def traverse(node, state: daglish.State):
for arg in node.__signature_info__.parameters.values():
if arg.default is not arg.empty and arg.name not in node.__arguments__:
setattr(node, arg.name, arg.default)
if state.is_traversable(node):
state.flattened_map_children(node)
for _ in state.yield_map_child_values(node, ignore_leaves=True):
pass # Run lazy iterator.

daglish.MemoizedTraversal.run(traverse, value)
4 changes: 2 additions & 2 deletions fiddle/_src/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ class _BuiltArgFactory:
def _contains_arg_factory(value: Any) -> bool:
"""Returns true if ``value`` contains any ``_BuiltArgFactory`` instances."""

def visit(node, state):
def visit(node, state: daglish.State):
if isinstance(node, _BuiltArgFactory):
return True
elif state.is_traversable(node):
return any(state.flattened_map_children(node).values)
return any(state.yield_map_child_values(node))
else:
return False

Expand Down
Loading