Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions fiddle/_src/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dataclasses
import inspect
import types
from typing import Any, Iterator, List, Optional, Type
from typing import Any, Dict, Iterator, List, Optional, Type

from fiddle._src import config
from fiddle._src import daglish
Expand Down Expand Up @@ -131,8 +131,9 @@ def _get_tags(cfg, path):
return None


def _rearrange_buildable_args_and_insert_unset_sentinels(
value: config.Buildable) -> config.Buildable:
def _rearrange_buildable_args(
value: config.Buildable, insert_unset_sentinels: bool = True
) -> config.Buildable:
"""Returns a copy of a Buildable with normalized arguments.

This normalizes arguments by re-creating the __arguments__ dictionary in the
Expand All @@ -141,6 +142,8 @@ def _rearrange_buildable_args_and_insert_unset_sentinels(

Args:
value: Buildable to copy and normalize.
insert_unset_sentinels: If true, insert unset sentinels to arguments as the
default values.

Returns:
Copy of `value` with arguments normalized.
Expand All @@ -155,7 +158,7 @@ def _rearrange_buildable_args_and_insert_unset_sentinels(
continue
elif param_name in old_arguments:
new_arguments[param_name] = old_arguments.pop(param_name)
else:
elif insert_unset_sentinels:
new_arguments[param_name] = _UnsetValue(param)
new_arguments.update(old_arguments) # Add in kwargs, in current order.
object.__setattr__(value, '__arguments__', new_arguments)
Expand Down Expand Up @@ -190,7 +193,7 @@ def generate(value, state=None) -> Iterator[_LeafSetting]:

# Rearrange parameters in signature order, and add "unset" sentinels.
if isinstance(value, config.Buildable):
value = _rearrange_buildable_args_and_insert_unset_sentinels(value)
value = _rearrange_buildable_args(value)

if isinstance(value, tagging.TaggedValueCls):
value = _TaggedValueWrapper(value)
Expand Down Expand Up @@ -331,3 +334,43 @@ def make_previous_text(entry: history.HistoryEntry) -> str:
value_history[-1].new_value, raw_value_repr=raw_value_repr)
current = f'{current_value} @ {value_history[-1].location}'
return f'{_path_str(path)} = {current}{past}'


def as_dict_flattened(cfg: config.Buildable) -> Dict[str, Any]:
"""Returns a flattened dict of cfg's paths (dot syntax) and values.

Default values won't be included in the flattened dict.

Args:
cfg: A buildable to generate a string representation for.

Returns: A flattened Dict representation of `cfg`.
"""

def dict_generate(value, state=None) -> Iterator[_LeafSetting]:
state = state or daglish.BasicTraversal.begin(dict_generate, value)

tags = _get_tags(cfg, state.current_path)
if tags:
value = tagging.TaggedValue(tags=tags, default=value)

# Rearrange parameters in signature order, and add "unset" sentinels.
if isinstance(value, config.Buildable):
value = _rearrange_buildable_args(value, insert_unset_sentinels=False)

if isinstance(value, tagging.TaggedValueCls):
value = _TaggedValueWrapper(value)
yield _LeafSetting(state.current_path, None, value)
elif not _has_nested_builder(value):
yield _LeafSetting(state.current_path, None, value)
else:
# value must be a Buildable or a traversable containing a Buidable.
assert state.is_traversable(value)
for sub_result in state.flattened_map_children(value).values:
yield from sub_result

args_dict = {}
for leaf in dict_generate(cfg):
args_dict[_path_str(leaf.path)] = leaf.value

return args_dict
160 changes: 160 additions & 0 deletions fiddle/_src/printing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,165 @@ def test_collection_of_two_buildables_history(self):
self.assertRegex(output, expected)


class AsFlattenedDictTests(absltest.TestCase):

def test_simple_flattened_dict(self):
cfg = fdl.Config(fn_x_y, 1, 'abc')
output = printing.as_dict_flattened(cfg)

expected = {'x': 1, 'y': 'abc'}
self.assertEqual(output, expected)

def test_skip_unset_argument(self):
cfg = fdl.Config(fn_x_y, 3.14)
output = printing.as_dict_flattened(cfg)

expected = {'x': 3.14}
self.assertEqual(output, expected)

def test_nested(self):
cfg = fdl.Config(fn_x_y, 'x', fdl.Config(fn_x_y, 'nest_x', 123))
output = printing.as_dict_flattened(cfg)

expected = {'x': 'x', 'y.x': 'nest_x', 'y.y': 123}
self.assertEqual(output, expected)

def test_class(self):
cfg = fdl.Config(SampleClass, 'a_param', b=123)
output = printing.as_dict_flattened(cfg)

expected = {'a': 'a_param', 'b': 123}
self.assertEqual(output, expected)

def test_kwargs(self):
cfg = fdl.Config(fn_with_kwargs, 1, abc='extra kwarg value')
output = printing.as_dict_flattened(cfg)

expected = {'x': 1, 'abc': 'extra kwarg value'}
self.assertEqual(output, expected)

def test_nested_kwargs(self):
cfg = fdl.Config(
fn_with_kwargs, extra=fdl.Config(fn_with_kwargs, 1, nested_extra='whee')
)
output = printing.as_dict_flattened(cfg)

expected = {'extra.x': 1, 'extra.nested_extra': 'whee'}
self.assertEqual(output, expected)

def test_nested_collections(self):
cfg = fdl.Config(
fn_x_y, [fdl.Config(fn_x_y, 1, '1'), fdl.Config(SampleClass, 2)]
)
output = printing.as_dict_flattened(cfg)

expected = {'x[0].x': 1, 'x[0].y': '1', 'x[1].a': 2}
self.assertEqual(output, expected)

def test_multiple_nested_collections(self):
cfg = fdl.Config(
fn_x_y,
{'a': fdl.Config(fn_with_kwargs, abc=[1, 2, 3]), 'b': [3, 2, 1]},
[fdl.Config(fn_x_y, [fdl.Config(fn_x_y, 1, 2)])],
)
output = printing.as_dict_flattened(cfg)

expected = {
"x['a'].abc": [1, 2, 3],
"x['b']": [3, 2, 1],
'y[0].x[0].x': 1,
'y[0].x[0].y': 2,
}
self.assertEqual(output, expected)

def test_skip_default_values(self):
def test_fn(w, x, y=3, z='abc'): # pylint: disable=unused-argument
pass

cfg = fdl.Config(test_fn, 1)
output = printing.as_dict_flattened(cfg)

expected = {'w': 1}
self.assertEqual(output, expected)

def test_tagged_values(self):
cfg = fdl.Config(fn_x_y, x=SampleTag.new(), y=SampleTag.new(default='abc'))
output = printing.as_dict_flattened(cfg)

expected = "'abc' #__main__.SampleTag"
self.assertEqual(repr(output['y']), expected)

fdl.set_tagged(cfg, tag=SampleTag, value='cba')
output = printing.as_dict_flattened(cfg)

expected = "'cba' #__main__.SampleTag"
self.assertEqual(repr(output['x']), expected)
self.assertEqual(repr(output['y']), expected)

def test_partial(self):
partial = fdl.Partial(fn_x_y)
partial.x = 'abc'
output = printing.as_dict_flattened(partial)

expected = {'x': 'abc'}
self.assertEqual(output, expected)

def test_builtin_types_annotations(self):
cfg = fdl.Config(fn_with_type_annotations, 1)
cfg.y = 'abc'
output = printing.as_dict_flattened(cfg)

expected = {'x': 1, 'y': 'abc'}
self.assertEqual(output, expected)

def test_annotated_kwargs(self):
cfg = fdl.Config(annotated_kwargs_helper, x=1, y='oops')
output = printing.as_dict_flattened(cfg)

expected = {'x': 1, 'y': 'oops'}
self.assertEqual(output, expected)

def test_disabling_type_annotations(self):
cfg = fdl.Config(fn_with_type_annotations, 1)
cfg.y = 'abc'
output = printing.as_dict_flattened(cfg)

expected = {'x': 1, 'y': 'abc'}
self.assertEqual(output, expected)

def test_union_type(self):
def to_integer(x: Union[int, str]):
return int(x)

cfg = fdl.Config(to_integer, 1)
output = printing.as_dict_flattened(cfg)

expected = {'x': 1}
self.assertEqual(output, expected)

def test_parameterized_generic(self):
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
self.skipTest('types.GenericAlias is 3.9+ only.')

def takes_list(x: list[int]):
return x

cfg = fdl.Config(takes_list, [1, 2, 3])
output = printing.as_dict_flattened(cfg)

expected = {'x': [1, 2, 3]}
self.assertEqual(output, expected)

def test_materialized_default_values(self):
def test_fn(w, x, y=3, z='abc'):
del w, x, y, z # Unused.

cfg = fdl.Config(test_fn, 1)
fdl.materialize_defaults(cfg)
output = printing.as_dict_flattened(cfg)
expected = {'w': 1, 'y': 3, 'z': 'abc'}
self.assertEqual(output, expected)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions fiddle/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
"""Functions to output representations of `fdl.Buildable`s."""

# pylint: disable=unused-import
from fiddle._src.printing import as_dict_flattened
from fiddle._src.printing import as_str_flattened
from fiddle._src.printing import history_per_leaf_parameter