From 1bc59f402784708ab82278c3c055a1fec5e34ed8 Mon Sep 17 00:00:00 2001 From: Fiddle-Config Team Date: Mon, 25 Sep 2023 23:00:48 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 568426297 --- fiddle/_src/printing.py | 53 ++++++++++-- fiddle/_src/printing_test.py | 160 +++++++++++++++++++++++++++++++++++ fiddle/printing.py | 1 + 3 files changed, 209 insertions(+), 5 deletions(-) diff --git a/fiddle/_src/printing.py b/fiddle/_src/printing.py index 5d4383d5..3eecc711 100644 --- a/fiddle/_src/printing.py +++ b/fiddle/_src/printing.py @@ -19,7 +19,7 @@ import dataclasses import inspect import types -from typing import Any, Iterator, List, Optional, Type +from typing import Any, Dict, Iterator, List, Optional, Type from fiddle._src import config from fiddle._src import daglish @@ -131,8 +131,9 @@ def _get_tags(cfg, path): return None -def _rearrange_buildable_args_and_insert_unset_sentinels( - value: config.Buildable) -> config.Buildable: +def _rearrange_buildable_args( + value: config.Buildable, insert_unset_sentinels: bool = True +) -> config.Buildable: """Returns a copy of a Buildable with normalized arguments. This normalizes arguments by re-creating the __arguments__ dictionary in the @@ -141,6 +142,8 @@ def _rearrange_buildable_args_and_insert_unset_sentinels( Args: value: Buildable to copy and normalize. + insert_unset_sentinels: If true, insert unset sentinels to arguments as the + default values. Returns: Copy of `value` with arguments normalized. @@ -155,7 +158,7 @@ def _rearrange_buildable_args_and_insert_unset_sentinels( continue elif param_name in old_arguments: new_arguments[param_name] = old_arguments.pop(param_name) - else: + elif insert_unset_sentinels: new_arguments[param_name] = _UnsetValue(param) new_arguments.update(old_arguments) # Add in kwargs, in current order. object.__setattr__(value, '__arguments__', new_arguments) @@ -190,7 +193,7 @@ def generate(value, state=None) -> Iterator[_LeafSetting]: # Rearrange parameters in signature order, and add "unset" sentinels. if isinstance(value, config.Buildable): - value = _rearrange_buildable_args_and_insert_unset_sentinels(value) + value = _rearrange_buildable_args(value) if isinstance(value, tagging.TaggedValueCls): value = _TaggedValueWrapper(value) @@ -331,3 +334,43 @@ def make_previous_text(entry: history.HistoryEntry) -> str: value_history[-1].new_value, raw_value_repr=raw_value_repr) current = f'{current_value} @ {value_history[-1].location}' return f'{_path_str(path)} = {current}{past}' + + +def as_dict_flattened(cfg: config.Buildable) -> Dict[str, Any]: + """Returns a flattened dict of cfg's paths (dot syntax) and values. + + Default values won't be included in the flattened dict. + + Args: + cfg: A buildable to generate a string representation for. + + Returns: A flattened Dict representation of `cfg`. + """ + + def dict_generate(value, state=None) -> Iterator[_LeafSetting]: + state = state or daglish.BasicTraversal.begin(dict_generate, value) + + tags = _get_tags(cfg, state.current_path) + if tags: + value = tagging.TaggedValue(tags=tags, default=value) + + # Rearrange parameters in signature order, and add "unset" sentinels. + if isinstance(value, config.Buildable): + value = _rearrange_buildable_args(value, insert_unset_sentinels=False) + + if isinstance(value, tagging.TaggedValueCls): + value = _TaggedValueWrapper(value) + yield _LeafSetting(state.current_path, None, value) + elif not _has_nested_builder(value): + yield _LeafSetting(state.current_path, None, value) + else: + # value must be a Buildable or a traversable containing a Buidable. + assert state.is_traversable(value) + for sub_result in state.flattened_map_children(value).values: + yield from sub_result + + args_dict = {} + for leaf in dict_generate(cfg): + args_dict[_path_str(leaf.path)] = leaf.value + + return args_dict diff --git a/fiddle/_src/printing_test.py b/fiddle/_src/printing_test.py index f7dd619f..3c02bf48 100644 --- a/fiddle/_src/printing_test.py +++ b/fiddle/_src/printing_test.py @@ -473,5 +473,165 @@ def test_collection_of_two_buildables_history(self): self.assertRegex(output, expected) +class AsFlattenedDictTests(absltest.TestCase): + + def test_simple_flattened_dict(self): + cfg = fdl.Config(fn_x_y, 1, 'abc') + output = printing.as_dict_flattened(cfg) + + expected = {'x': 1, 'y': 'abc'} + self.assertEqual(output, expected) + + def test_skip_unset_argument(self): + cfg = fdl.Config(fn_x_y, 3.14) + output = printing.as_dict_flattened(cfg) + + expected = {'x': 3.14} + self.assertEqual(output, expected) + + def test_nested(self): + cfg = fdl.Config(fn_x_y, 'x', fdl.Config(fn_x_y, 'nest_x', 123)) + output = printing.as_dict_flattened(cfg) + + expected = {'x': 'x', 'y.x': 'nest_x', 'y.y': 123} + self.assertEqual(output, expected) + + def test_class(self): + cfg = fdl.Config(SampleClass, 'a_param', b=123) + output = printing.as_dict_flattened(cfg) + + expected = {'a': 'a_param', 'b': 123} + self.assertEqual(output, expected) + + def test_kwargs(self): + cfg = fdl.Config(fn_with_kwargs, 1, abc='extra kwarg value') + output = printing.as_dict_flattened(cfg) + + expected = {'x': 1, 'abc': 'extra kwarg value'} + self.assertEqual(output, expected) + + def test_nested_kwargs(self): + cfg = fdl.Config( + fn_with_kwargs, extra=fdl.Config(fn_with_kwargs, 1, nested_extra='whee') + ) + output = printing.as_dict_flattened(cfg) + + expected = {'extra.x': 1, 'extra.nested_extra': 'whee'} + self.assertEqual(output, expected) + + def test_nested_collections(self): + cfg = fdl.Config( + fn_x_y, [fdl.Config(fn_x_y, 1, '1'), fdl.Config(SampleClass, 2)] + ) + output = printing.as_dict_flattened(cfg) + + expected = {'x[0].x': 1, 'x[0].y': '1', 'x[1].a': 2} + self.assertEqual(output, expected) + + def test_multiple_nested_collections(self): + cfg = fdl.Config( + fn_x_y, + {'a': fdl.Config(fn_with_kwargs, abc=[1, 2, 3]), 'b': [3, 2, 1]}, + [fdl.Config(fn_x_y, [fdl.Config(fn_x_y, 1, 2)])], + ) + output = printing.as_dict_flattened(cfg) + + expected = { + "x['a'].abc": [1, 2, 3], + "x['b']": [3, 2, 1], + 'y[0].x[0].x': 1, + 'y[0].x[0].y': 2, + } + self.assertEqual(output, expected) + + def test_skip_default_values(self): + def test_fn(w, x, y=3, z='abc'): # pylint: disable=unused-argument + pass + + cfg = fdl.Config(test_fn, 1) + output = printing.as_dict_flattened(cfg) + + expected = {'w': 1} + self.assertEqual(output, expected) + + def test_tagged_values(self): + cfg = fdl.Config(fn_x_y, x=SampleTag.new(), y=SampleTag.new(default='abc')) + output = printing.as_dict_flattened(cfg) + + expected = "'abc' #__main__.SampleTag" + self.assertEqual(repr(output['y']), expected) + + fdl.set_tagged(cfg, tag=SampleTag, value='cba') + output = printing.as_dict_flattened(cfg) + + expected = "'cba' #__main__.SampleTag" + self.assertEqual(repr(output['x']), expected) + self.assertEqual(repr(output['y']), expected) + + def test_partial(self): + partial = fdl.Partial(fn_x_y) + partial.x = 'abc' + output = printing.as_dict_flattened(partial) + + expected = {'x': 'abc'} + self.assertEqual(output, expected) + + def test_builtin_types_annotations(self): + cfg = fdl.Config(fn_with_type_annotations, 1) + cfg.y = 'abc' + output = printing.as_dict_flattened(cfg) + + expected = {'x': 1, 'y': 'abc'} + self.assertEqual(output, expected) + + def test_annotated_kwargs(self): + cfg = fdl.Config(annotated_kwargs_helper, x=1, y='oops') + output = printing.as_dict_flattened(cfg) + + expected = {'x': 1, 'y': 'oops'} + self.assertEqual(output, expected) + + def test_disabling_type_annotations(self): + cfg = fdl.Config(fn_with_type_annotations, 1) + cfg.y = 'abc' + output = printing.as_dict_flattened(cfg) + + expected = {'x': 1, 'y': 'abc'} + self.assertEqual(output, expected) + + def test_union_type(self): + def to_integer(x: Union[int, str]): + return int(x) + + cfg = fdl.Config(to_integer, 1) + output = printing.as_dict_flattened(cfg) + + expected = {'x': 1} + self.assertEqual(output, expected) + + def test_parameterized_generic(self): + if not (sys.version_info.major == 3 and sys.version_info.minor >= 9): + self.skipTest('types.GenericAlias is 3.9+ only.') + + def takes_list(x: list[int]): + return x + + cfg = fdl.Config(takes_list, [1, 2, 3]) + output = printing.as_dict_flattened(cfg) + + expected = {'x': [1, 2, 3]} + self.assertEqual(output, expected) + + def test_materialized_default_values(self): + def test_fn(w, x, y=3, z='abc'): + del w, x, y, z # Unused. + + cfg = fdl.Config(test_fn, 1) + fdl.materialize_defaults(cfg) + output = printing.as_dict_flattened(cfg) + expected = {'w': 1, 'y': 3, 'z': 'abc'} + self.assertEqual(output, expected) + + if __name__ == '__main__': absltest.main() diff --git a/fiddle/printing.py b/fiddle/printing.py index c67495b8..4b42aa01 100644 --- a/fiddle/printing.py +++ b/fiddle/printing.py @@ -16,5 +16,6 @@ """Functions to output representations of `fdl.Buildable`s.""" # pylint: disable=unused-import +from fiddle._src.printing import as_dict_flattened from fiddle._src.printing import as_str_flattened from fiddle._src.printing import history_per_leaf_parameter