From 98d6e374056ee75f3f69099d0e644d835e003c42 Mon Sep 17 00:00:00 2001 From: Zhufeng Pan Date: Thu, 24 Aug 2023 10:48:54 -0700 Subject: [PATCH] Add initial support for positional args support within `fdl.Config`. To access positional args: ```python v = config[:] # the full list v = config[-1] # normal index v = config[:3] # slice index ``` To modify positional args: ```python config[:] = [1, 2] # assign to a new list config[1] = 2 ``` PiperOrigin-RevId: 559802927 --- fiddle/__init__.py | 1 + fiddle/_src/building.py | 42 +++- fiddle/_src/config.py | 277 ++++++++++++++++++++++--- fiddle/_src/config_test.py | 214 +++++++++++++++++-- fiddle/_src/signatures.py | 72 +++++-- fiddle/config.py | 1 + fiddle/examples/colabs/basic_api.ipynb | 45 +--- 7 files changed, 525 insertions(+), 127 deletions(-) diff --git a/fiddle/__init__.py b/fiddle/__init__.py index 33146324..0ab98d0f 100644 --- a/fiddle/__init__.py +++ b/fiddle/__init__.py @@ -26,6 +26,7 @@ from fiddle._src.config import NO_VALUE from fiddle._src.config import ordered_arguments from fiddle._src.config import update_callable +from fiddle._src.config import VARARGS from fiddle._src.materialize import materialize_defaults from fiddle._src.partial import ArgFactory from fiddle._src.partial import Partial diff --git a/fiddle/_src/building.py b/fiddle/_src/building.py index 6212a6fb..9324bdd4 100644 --- a/fiddle/_src/building.py +++ b/fiddle/_src/building.py @@ -19,7 +19,7 @@ import functools import logging import threading -from typing import Any, Callable, Dict, TypeVar, overload +from typing import Any, Callable, Dict, Sequence, TypeVar, overload from fiddle._src import config as config_lib from fiddle._src import daglish @@ -60,8 +60,12 @@ def _format_arg(arg: Any) -> str: return f'' -def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable, - arguments: Dict[str, Any]) -> str: +def _make_message( + current_path: daglish.Path, + buildable: config_lib.Buildable, + args: Sequence[Any], + kwargs: Dict[str, Any], +) -> str: """Returns Fiddle-related debugging information for an exception.""" path_str = '' + daglish.path_str(current_path) fn_or_cls = config_lib.get_callable(buildable) @@ -69,11 +73,15 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable, fn_or_cls_name = fn_or_cls.__qualname__ except AttributeError: fn_or_cls_name = str(fn_or_cls) # callable instances, etc. + args_str = ', '.join(f'{_format_arg(value)}' for value in args) kwargs_str = ', '.join( - f'{name}={_format_arg(value)}' for name, value in arguments.items()) + f'{name}={_format_arg(value)}' for name, value in kwargs.items() + ) tag_information = '' - bound_args = buildable.__signature_info__.signature.bind_partial(**arguments) + bound_args = buildable.__signature_info__.signature.bind_partial( + *args, **kwargs + ) bound_args.apply_defaults() unset_arg_tags = [] for param in buildable.__signature_info__.parameters: @@ -90,7 +98,8 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable, return ( '\n\nFiddle context: failed to construct or call ' f'{fn_or_cls_name} at {path_str} ' - f'with arguments ({kwargs_str}){tag_information}' + f'with positional arguments: ({args_str}), ' + f'keyword arguments: ({kwargs_str}){tag_information}.' ) @@ -100,10 +109,25 @@ def call_buildable( *, current_path: daglish.Path, ) -> Any: - make_message = functools.partial(_make_message, current_path, buildable, - arguments) + """Prepare positional arguments and actually build the buildable.""" + positional_only, keyword_or_positional, var_positional = ( + buildable.__signature_info__.get_positional_names() + ) + positional_arguments = [] + for name in positional_only: + if name in arguments: + positional_arguments.append(arguments.pop(name)) + if var_positional is not None: + for name in keyword_or_positional: + if name in arguments: + positional_arguments.append(arguments.pop(name)) + if var_positional in arguments: + positional_arguments.extend(arguments.pop(var_positional)) + make_message = functools.partial( + _make_message, current_path, buildable, positional_arguments, arguments + ) with reraised_exception.try_with_lazy_message(make_message): - return buildable.__build__(**arguments) + return buildable.__build__(*positional_arguments, **arguments) # Define typing overload for `build(Partial[T])` diff --git a/fiddle/_src/config.py b/fiddle/_src/config.py index 44d2b6d6..166b4a29 100644 --- a/fiddle/_src/config.py +++ b/fiddle/_src/config.py @@ -22,6 +22,7 @@ import copy import dataclasses import functools +import inspect import types from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union @@ -58,6 +59,9 @@ def __copy__(self): # None or other commonly-used sentinel. _UNSET_SENTINEL = object() +# Unique object instance that represents the index where varadic positional +# arguments start for a Buildable. +VARARGS = object() _defaults_aware_traverser_registry = daglish.NodeTraverserRegistry( use_fallback=True @@ -247,7 +251,7 @@ def __init__( ) for name, value in arguments.items(): - setattr(self, name, value) + self._setattr(name, value, allow_postional_argument=True) for name, tags in tag_type.find_tags_from_annotations(fn_or_cls).items(): self.__argument_tags__[name].update(tags) @@ -258,6 +262,7 @@ def __init__( def __init_callable__( self, fn_or_cls: Union['Buildable[T]', TypeOrCallableProducingT[T]] ) -> None: + """Save information on `fn_or_cls` to the `Buildable`.""" if isinstance(fn_or_cls, Buildable): raise ValueError( 'Using the Buildable constructor to convert a buildable to a new ' @@ -273,9 +278,11 @@ def __init_callable__( super().__setattr__('__fn_or_cls__', fn_or_cls) super().__setattr__('__arguments__', {}) signature = signatures.get_signature(fn_or_cls) + # Several attributes are computed automatically by SignatureInfo during + # `__post_init__`. super().__setattr__( '__signature_info__', - signatures.SignatureInfo(signature), + signatures.SignatureInfo(signature=signature), ) def __init_subclass__(cls): @@ -312,6 +319,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]: def __getattr__(self, name: str): """Get parameter with given ``name``.""" value = self.__arguments__.get(name, _UNSET_SENTINEL) + param = self.__signature_info__.parameters.get(name) + if param is not None and ( + param.kind in (param.POSITIONAL_ONLY, param.VAR_POSITIONAL) + ): + raise AttributeError( + 'Cannot access positional-only or variadic positional arguments ' + f'{name} on {self!r} by attributes.' + ) if value is not _UNSET_SENTINEL: return value @@ -323,7 +338,6 @@ def __getattr__(self, name: str): + f'{self.__fn_or_cls__.__qualname__}.{name} ' + 'since it uses a default_factory.' ) - param = self.__signature_info__.parameters.get(name) if param is not None and param.default is not param.empty: return param.default msg = f"No parameter '{name}' has been set on {self!r}." @@ -340,10 +354,15 @@ def __getattr__(self, name: str): ) raise AttributeError(msg) - def __setattr__(self, name: str, value: Any): - """Sets parameter ``name`` to ``value``.""" - - self.__signature_info__.validate_param_name(name, self.__fn_or_cls__) + def _setattr( + self, name: str, value: Any, allow_postional_argument: bool = False + ): + """The __setattr__ implementation.""" + self.__signature_info__.validate_param_name( + name, + self.__fn_or_cls__, + allow_postional_argument=allow_postional_argument, + ) if isinstance(value, TaggedValueCls): tags = value.__argument_tags__.get('value', ()) @@ -360,6 +379,10 @@ def __setattr__(self, name: str, value: Any): self.__arguments__[name] = value self.__argument_history__.add_new_value(name, value) + def __setattr__(self, name: str, value: Any): + """Sets parameter ``name`` to ``value``.""" + self._setattr(name, value) + def __delattr__(self, name): """Unsets parameter ``name``.""" try: @@ -369,6 +392,90 @@ def __delattr__(self, name): err = AttributeError(f"No parameter '{name}' has been set on {self!r}") raise err from None + def _get_all_positional_args(self): + """Get a full list of positional arguments.""" + positional_only, keyword_or_positional, var_positional = ( + self.__signature_info__.get_positional_names() + ) + positional_arguments = [] + for name in positional_only: + positional_arguments.append(self.__arguments__[name]) + if var_positional: + for name in keyword_or_positional: + positional_arguments.append(self.__arguments__[name]) + positional_arguments += self.__arguments__.get(var_positional, []) + return positional_arguments + + def _replace_varargs_handle(self, key): + """Replace VARARGS handle in index key if exists.""" + positional_only, keyword_or_positional, _ = ( + self.__signature_info__.get_positional_names() + ) + start = len(positional_only) + len(keyword_or_positional) + if isinstance(key, slice) and key.start is VARARGS: + return slice(start, key.stop, key.step) + elif key is VARARGS: + return start + return key + + def __getitem__(self, key: Any): + """Get positional arguments by index.""" + key = self._replace_varargs_handle(key) + all_positional_args = self._get_all_positional_args() + return all_positional_args[key] + + def __setitem__(self, key: Any, value: Any): + """Set positional arguments by index.""" + key = self._replace_varargs_handle(key) + assert isinstance( + key, (int, slice) + ), f'Key of __setitem__ must be an int or slice, got {key}.' + positional_only, keyword_or_positional, var_positional = ( + self.__signature_info__.get_positional_names() + ) + positional_names = positional_only + if var_positional: + positional_names += keyword_or_positional + old_positional_args = self._get_all_positional_args() + # Set positional arguments values using a comparison approach. + # Because setting values directly will lead to very complex logics due to + # various indices patterns, as well as the case where key is a slice but + # the value is not a sequence object. + new_positional_args = copy.deepcopy(old_positional_args) + new_positional_args[key] = value + + # Handle non-variadic positional arguments + for index, name in enumerate(positional_names): + if index < len(new_positional_args): + if old_positional_args[index] != new_positional_args[index]: + new_value = new_positional_args[index] + self.__arguments__[name] = new_value + self.__argument_history__.add_new_value(name, new_value) + else: + del self.__arguments__[name] + self.__argument_history__.add_deleted_value(name) + + # Handle variadic positional arguments + if var_positional is None: + if len(new_positional_args) > len(positional_names): + raise ValueError( + 'Too many arguments are provided. There are only ' + f'{len(positional_names)} positional arguments but ' + f'{len(new_positional_args)} are provided to ' + f'{self.__fn_or_cls__.__qualname__}.' + ) + else: + if len(new_positional_args) <= len(positional_names): + del self.__arguments__[var_positional] + self.__argument_history__.add_deleted_value(var_positional) + else: + new_var_positional_arg = new_positional_args[len(positional_names) :] + if new_var_positional_arg != self.__arguments__.get(var_positional, []): + self.__arguments__[var_positional] = new_var_positional_arg + self.__argument_history__.add_new_value( + var_positional, new_var_positional_arg + ) + def __dir__(self) -> Collection[str]: """Provide a useful list of attribute names, optimized for Jupyter/Colab. @@ -488,9 +595,7 @@ def __getstate__(self): Dict of serialized state. """ result = dict(self.__dict__) - result['__signature_info__'] = signatures.SignatureInfo( # pytype: disable=wrong-arg-types - None, result['__signature_info__'].has_var_keyword - ) + result['__signature_info__'] = signatures.SignatureInfo(None) # pytype: disable=wrong-arg-types return result def __setstate__(self, state) -> None: @@ -503,8 +608,10 @@ def __setstate__(self, state) -> None: """ self.__dict__.update(state) # Support unpickle. if self.__signature_info__.signature is None: - self.__signature_info__.signature = signatures.get_signature( - self.__fn_or_cls__ + signature = signatures.get_signature(self.__fn_or_cls__) + super().__setattr__( + '__signature_info__', + signatures.SignatureInfo(signature=signature), ) @@ -637,11 +744,123 @@ def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str): return False +def _update_positional_args( + buildable: Buildable, + original_signature: inspect.Signature, + new_signature: inspect.Signature, + drop_invalid_args: bool = False, +) -> None: + """Update positional arguments in place. + + The naive approach to update positional arguments when changing the callable + is to update each individual argument. However, the mapping problem (for + example, some positional arugments may have differnt names now, or need to map + *args to concrete positional arguments) is very challenging, because there + are mulitple possible conditions depending on if the origial and the new + callable have varadic positional arugments. + + This method adopts the approach that first builds a full list of all + positional arguments, and then try to map the argument list accroding to the + signature of new callable. + + Args: + buildable: A ``Buildable`` (e.g. a ``fdl.Config``) to update. + original_signature: Signature of the original callable. + new_signature: Signature of the new callable. + drop_invalid_args: If True, arguments that don't exist in the new callable + will be removed from buildable. If False, raise an exception for such + arguments. + + Raises: + TypeError: If fails to match the origial positional arguments to the new + callable. + """ + positional_argument_names = [] + positional_argument_values = [] + keyword_or_positional_argument_names = [] + keyword_or_positional_argument_values = [] + var_positional_name = None + + for param in original_signature.parameters.values(): + if param.name not in buildable.__arguments__: + break + if param.kind == param.POSITIONAL_ONLY: + positional_argument_names.append(param.name) + value = buildable.__arguments__[param.name] + positional_argument_values.append(value) + if param.kind == param.POSITIONAL_OR_KEYWORD: + keyword_or_positional_argument_names.append(param.name) + value = buildable.__arguments__[param.name] + keyword_or_positional_argument_values.append(value) + if param.kind == param.VAR_POSITIONAL: + var_positional_name = param.name + values = buildable.__arguments__[param.name] + if values: + # if *args exist, keyword-or-positional arguments will become + # positional arguments. + positional_argument_names.extend(keyword_or_positional_argument_names) + positional_argument_values.extend(keyword_or_positional_argument_values) + positional_argument_names.extend([None for _ in values]) + positional_argument_values.extend(values) + break + + for index, param in enumerate(new_signature.parameters.values()): + if index >= len(positional_argument_values): + break + new_name = param.name + old_name = positional_argument_names[index] + if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): + new_value = positional_argument_values[index] + if old_name == new_name and ( + buildable.__arguments__[old_name] == new_value + ): + continue + buildable.__arguments__[new_name] = new_value + buildable.__argument_history__.add_new_value(new_name, new_value) + if old_name != new_name and old_name is not None: + del buildable.__arguments__[old_name] + buildable.__argument_history__.add_deleted_value(old_name) + if param.kind == param.VAR_POSITIONAL: + # All positional arguments will be matched to *args, so delte all + # remaining positional-only and positional-or-keyword arguments in the + # current `__arguments__` dict. + new_value = positional_argument_values[index:] + for name in positional_argument_names[index:]: + if name and name in buildable.__arguments__: + del buildable.__arguments__[name] + buildable.__argument_history__.add_deleted_value(name) + if var_positional_name and var_positional_name in buildable.__arguments__: + if new_name != var_positional_name: + del buildable.__arguments__[var_positional_name] + buildable.__argument_history__.add_deleted_value(var_positional_name) + else: + # Varadic positional arguments have the same name and value + if new_value == buildable.__arguments__[var_positional_name]: + break + buildable.__arguments__[new_name] = new_value + buildable.__argument_history__.add_new_value(param.kind, new_value) + # All positional arguments have been matched, exit the for loop. + break + if param.kind in (param.KEYWORD_ONLY, param.VAR_KEYWORD): + if drop_invalid_args: + raise NotImplementedError( + 'Drop invalid positional arguments are not supported yet.' + ) + else: + raise TypeError( + f'Fail to match buildable {buildable} from signature' + f'{original_signature} to {new_signature}.' + ) + if var_positional_name in buildable.__arguments__: + del buildable.__arguments__[var_positional_name] + buildable.__argument_history__.add_deleted_value(var_positional_name) + + def update_callable( buildable: Buildable, new_callable: TypeOrCallableProducingT, drop_invalid_args: bool = False, -): +) -> None: """Updates ``config`` to build ``new_callable`` instead. When extending a base configuration, it can often be useful to swap one class @@ -666,24 +885,26 @@ def update_callable( # # Note: can't call `setattr` on all the args to validate them, because that # will result in duplicate history entries. - original_args = buildable.__arguments__ - signature = signatures.get_signature(new_callable) - if any( - param.kind == param.VAR_POSITIONAL - for param in signature.parameters.values() - ): - raise NotImplementedError( - 'Variable positional arguments (aka `*args`) not supported.' - ) - signature_info = signatures.SignatureInfo(signature) - object.__setattr__( + new_signature = signatures.get_signature(new_callable) + # Update the signature early so that we can set arguments by position. + # Otherwise, parameter validation logics would complain about argument + # name not exists. + object.__setattr__(buildable, '__signature__', new_signature) + new_signature_info = signatures.SignatureInfo(signature=new_signature) + original_signature_info = buildable.__signature_info__ + object.__setattr__(buildable, '__signature_info__', new_signature_info) + _update_positional_args( buildable, - '__signature_info__', - signature_info, + original_signature_info.signature, + new_signature_info.signature, + drop_invalid_args, ) - if not signature_info.has_var_keyword: + + if not new_signature_info.has_var_keyword: invalid_args = [ - arg for arg in original_args.keys() if arg not in signature.parameters + arg + for arg in buildable.__arguments__.keys() + if arg not in new_signature.parameters ] if invalid_args: if drop_invalid_args: diff --git a/fiddle/_src/config_test.py b/fiddle/_src/config_test.py index d03e21da..39e3e5cb 100644 --- a/fiddle/_src/config_test.py +++ b/fiddle/_src/config_test.py @@ -72,6 +72,14 @@ def fn_with_var_args_and_kwargs(arg1, *args, kwarg1=None, **kwargs): # pylint: return locals() +def fn_with_args_and_kwargs_only(*args, **kwargs): + return args, kwargs + + +def fn_with_position_args(a, b, /, c=1, *args): # pylint: disable=keyword-arg-before-vararg, unused-argument + return locals() + + def make_typed_config() -> fdl.Config[SampleClass]: """Helper function which returns a fdl.Config whose type is known.""" return fdl.Config(SampleClass, arg1=1, arg2=2) @@ -195,14 +203,109 @@ def test_config_for_functions_with_var_args_and_kwargs(self): fn_args = fdl.build(fn_config) self.assertEqual(fn_args['arg1'], 'arg1') - fn_config.args = 'kwarg_called_arg' fn_config.kwargs = 'kwarg_called_kwarg' fn_args = fdl.build(fn_config) self.assertEqual(fn_args['kwargs'], { - 'args': 'kwarg_called_arg', 'kwargs': 'kwarg_called_kwarg' }) + def test_postional_args_access(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + self.assertEqual(fn_config[0], 1) + self.assertEqual(fn_config[-1], 5) + self.assertSequenceEqual(fn_config[3:], [4, 5]) + self.assertSequenceEqual(fn_config[:], [1, 2, 3, 4, 5]) + + def test_positional_args_modification(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + fn_config[0] = 0 + self.assertSequenceEqual(fn_config[:], [0, 2, 3, 4, 5]) + fn_config[:3] = [5, 6, 7] + self.assertSequenceEqual(fn_config[:], [5, 6, 7, 4, 5]) + fn_config[3:] = [5, 6, 7] + self.assertSequenceEqual(fn_config[:], [5, 6, 7, 5, 6, 7]) + fn_config[:] = [1, 2, 3] + self.assertSequenceEqual(fn_config[:], [1, 2, 3]) + fn_config[:] += [4, 5] + self.assertSequenceEqual(fn_config[:], [1, 2, 3, 4, 5]) + + def test_varargs_index_handle(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + with self.subTest('access'): + self.assertEqual(fn_config[fdl.VARARGS], 4) + self.assertSequenceEqual(fn_config[fdl.VARARGS :], [4, 5]) + with self.subTest('modify'): + fn_config[fdl.VARARGS :] = [] + self.assertSequenceEqual(fn_config[:], [1, 2, 3]) + fn_config[fdl.VARARGS :] = [7, 8, 9] + self.assertSequenceEqual(fn_config[:], [1, 2, 3, 7, 8, 9]) + fn_config[fdl.VARARGS] = 0 + self.assertSequenceEqual(fn_config[:], [1, 2, 3, 0, 8, 9]) + + def test_modification_when_var_args_are_empty(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3) + self.assertEmpty(fn_config[fdl.VARARGS :]) + fn_config[:] = ['a', 'b', 'c', 'd', 'e'] + self.assertSequenceEqual(fn_config[:], ['a', 'b', 'c', 'd', 'e']) + + def test_positional_args_direct_access_is_forbidden(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + with self.assertRaisesRegex( + AttributeError, + 'Cannot access positional-only or variadic positional arguments', + ): + _ = fn_config.args + + with self.assertRaisesRegex( + AttributeError, + 'Cannot access positional-only or variadic positional arguments', + ): + _ = fn_config.a + + def test_positional_args_direct_modification_is_forbidden(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + with self.assertRaisesRegex( + AttributeError, 'Cannot access VAR_POSITIONAL parameter' + ): + fn_config.args = [0] + + with self.assertRaisesRegex( + AttributeError, 'Cannot access POSITIONAL_ONLY parameter' + ): + fn_config.a = 0 + + def test_positional_or_keyword_args_have_consistent_values(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + fn_config[2] = 'arg-c' + self.assertEqual(fn_config.c, 'arg-c') + fn_config.c = 'index-2' + self.assertEqual(fn_config[2], 'index-2') + + def test_index_out_of_range(self): + fn_config = fdl.Config(fn_with_var_args, 1, 2) + self.assertLen(fn_config[:], 2) + with self.assertRaisesRegex( + IndexError, 'list assignment index out of range' + ): + fn_config[2] = 'index-2' + with self.assertRaisesRegex(IndexError, 'list index out of range'): + _ = fn_config[2] + + def test_args_config_shallow_copy(self): + fn_config = fdl.Config(fn_with_var_args, 1, 2) + self.assertLen(fn_config[:], 2) + a_copy = fn_config[:] + a_copy.append('3') + self.assertLen(fn_config[:], 2) + self.assertLen(a_copy, 3) + + def test_args_config_build(self): + fn_config = fdl.Config(fn_with_position_args, 1, 2, 3, 4, 5) + self.assertEqual( + fdl.build(fn_config), + {'a': 1, 'b': 2, 'c': 3, 'args': (4, 5)}, + ) + def test_config_for_dicts(self): dict_config = fdl.Config(dict, a=1, b=2) dict_config.c = 3 @@ -616,18 +719,6 @@ def test_nonexistent_parameter_error(self): with self.assertRaisesRegex(TypeError, expected_msg): class_config.nonexistent_arg = 'error!' - def test_nonexistent_var_args_parameter_error(self): - fn_config = fdl.Config(fn_with_var_args) - expected_msg = (r'Variadic arguments \(e.g. \*args\) are not supported\.') - with self.assertRaisesRegex(TypeError, expected_msg): - fn_config.args = (1, 2, 3) - - def test_unsupported_var_args_error(self): - expected_msg = (r'Variable positional arguments \(aka `\*args`\) not ' - r'supported\.') - with self.assertRaisesRegex(NotImplementedError, expected_msg): - fdl.Config(fn_with_var_args, 1, 2, 3) - def test_build_inside_build(self): def inner_build(x: int) -> str: @@ -737,15 +828,21 @@ def test_build_nested_structure(self): def test_build_raises_nice_error_too_few_args(self): cfg = fdl.Config(basic_fn, fdl.Config(SampleClass, 1), 2) - with self.assertRaisesRegex( - TypeError, r'.*missing 1 required.*\n\n.*\.arg1.*arg1=1'): + with self.assertRaises(TypeError) as e: fdl.build(cfg) + self.assertEqual( + e.exception.proxy_message, # pytype: disable=attribute-error + '\n\nFiddle context: failed to construct or call SampleClass at ' + '.arg1 with positional arguments: (), keyword arguments: ' + '(arg1=1).', + ) def test_build_raises_exception_on_call(self): cfg = fdl.Config(raise_error) msg = ( 'My fancy exception\n\nFiddle context: failed to construct or call ' - 'raise_error at with arguments ()' + 'raise_error at with positional arguments: (), ' + 'keyword arguments: ().' ) with self.assertRaisesWithLiteralMatch(ValueError, msg): fdl.build(cfg) @@ -762,7 +859,9 @@ def test_build_error_path(self): self.assertEqual( e.exception.proxy_message, # pytype: disable=attribute-error '\n\nFiddle context: failed to construct or call basic_fn at .' - "arg1[1]['c'] with arguments (arg1=1)") + "arg1[1]['c'] with positional arguments: (), " + 'keyword arguments: (arg1=1).', + ) def test_multithreaded_build(self): """Two threads can each invoke build.build without interfering.""" @@ -926,11 +1025,80 @@ def test_update_callable_new_kwargs(self): } }, fdl.build(cfg)) - def test_update_callable_varargs(self): - cfg = fdl.Config(fn_with_var_kwargs, 1, 2) - with self.assertRaisesRegex(NotImplementedError, - 'Variable positional arguments'): - fdl.update_callable(cfg, fn_with_var_args_and_kwargs) + # For `update_callable` involves variadic positional arguments, we test + # four patterns below. + # Pattern 1: *args -> *args + def test_original_and_new_callable_have_var_positaionl(self): + cfg = fdl.Config(fn_with_var_args, 1, 2, kwarg1=3) + fdl.update_callable(cfg, fn_with_var_args_and_kwargs) + self.assertEqual(cfg.__arguments__, {'arg1': 1, 'args': [2], 'kwarg1': 3}) + self.assertEqual( + fdl.build(cfg), + {'arg1': 1, 'args': (2,), 'kwarg1': 3, 'kwargs': {}}, + ) + + def test_original_var_args_are_empty(self): + def foo(a, b, c, /, d=0, *args): # pylint: disable=keyword-arg-before-vararg, unused-argument + return locals() + + def bar(a, /, b=0, *args): # pylint: disable=keyword-arg-before-vararg, unused-argument + return locals() + + cfg = fdl.Config(foo, 1, 2, 3) + fdl.update_callable(cfg, bar) + self.assertEqual(cfg.__arguments__, {'a': 1, 'b': 2, 'args': [3]}) + self.assertEqual( + fdl.build(cfg), + {'a': 1, 'b': 2, 'args': (3,)}, + ) + + def test_update_args_kwargs(self): + cfg = fdl.Config(fn_with_args_and_kwargs_only, 1, 2, 3, kwarg1=4, kwarg2=5) + cfg[0] = 10 + cfg.kwarg1 = 40 + config_lib.update_callable(cfg, fn_with_var_args_and_kwargs) + self.assertEqual( + cfg.__arguments__, + {'arg1': 10, 'args': [2, 3], 'kwarg1': 40, 'kwarg2': 5}, + ) + self.assertEqual( + fdl.build(cfg), + {'arg1': 10, 'args': (2, 3), 'kwarg1': 40, 'kwargs': {'kwarg2': 5}}, + ) + + # Pattern 2: *args -> no *args + def test_original_callable_has_var_positaionl(self): + def positional_fn(a, b, c, /, kwarg1): # pylint: disable=keyword-arg-before-vararg, unused-argument + return locals() + + cfg = fdl.Config(fn_with_var_args, 1, 2, 3, kwarg1=4) + fdl.update_callable(cfg, positional_fn) + cfg[1] = 22 + self.assertEqual(cfg.__arguments__, {'a': 1, 'b': 22, 'c': 3, 'kwarg1': 4}) + self.assertEqual(fdl.build(cfg), {'a': 1, 'b': 22, 'c': 3, 'kwarg1': 4}) + + # Pattern 3: no *args -> *args + def test_new_callable_has_var_positaionl(self): + cfg = fdl.Config(basic_fn, 1, 2, kwarg1=3) + fdl.update_callable(cfg, fn_with_var_args_and_kwargs) + self.assertEqual(cfg.__arguments__, {'arg1': 1, 'arg2': 2, 'kwarg1': 3}) + self.assertEqual( + fdl.build(cfg), + {'arg1': 1, 'args': (), 'kwarg1': 3, 'kwargs': {'arg2': 2}}, + ) + + # Pattern 4: no *args -> no *args + def test_no_var_positional_w_different_names(self): + def foo(x, y, /, z=0): # pylint: disable=keyword-arg-before-vararg, unused-argument + return locals() + + def bar(a, b, c=0, /, d='d'): # pylint: disable=keyword-arg-before-vararg, unused-argument + return locals() + + cfg = fdl.Config(foo, 1, 2) + fdl.update_callable(cfg, bar) + self.assertEqual(cfg.__arguments__, {'a': 1, 'b': 2}) + self.assertEqual(fdl.build(cfg), {'a': 1, 'b': 2, 'c': 0, 'd': 'd'}) def test_get_callable(self): cfg = fdl.Config(basic_fn) diff --git a/fiddle/_src/signatures.py b/fiddle/_src/signatures.py index f975bba2..def523e2 100644 --- a/fiddle/_src/signatures.py +++ b/fiddle/_src/signatures.py @@ -17,7 +17,7 @@ import dataclasses import inspect -from typing import Any, Callable, Dict, Generic, Mapping, Tuple, Type +from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type import weakref import typing_extensions @@ -135,15 +135,33 @@ class SignatureInfo: """To store signature related information about the callable.""" signature: inspect.Signature - has_var_keyword: bool = None + has_var_keyword: Optional[bool] = None + var_positional_name: Optional[str] = None + positional_arg_names: Optional[List[str]] = dataclasses.field( + default_factory=list + ) def __post_init__(self): - if self.has_var_keyword is None: - has_var_keyword = any( - param.kind == param.VAR_KEYWORD - for param in self.signature.parameters.values() - ) - self.has_var_keyword = has_var_keyword + # During serilization, signature is set to None so no action is needed. + if self.signature is None: + return + + # If *args exists, we must pass things before it in positional format. This + # list tracks those arguments. + maybe_positional_args = [] + positional_only_args = [] + for param in self.signature.parameters.values(): + if param.kind == param.POSITIONAL_ONLY: + positional_only_args.append(param.name) + elif param.kind == param.POSITIONAL_OR_KEYWORD: + maybe_positional_args.append(param.name) + elif param.kind == param.VAR_POSITIONAL: + positional_only_args.extend(maybe_positional_args) + if not self.var_positional_name: + self.var_positional_name = param.name + elif param.kind == param.VAR_KEYWORD and self.has_var_keyword is None: + self.has_var_keyword = True + self.positional_arg_names = positional_only_args @staticmethod def signature_binding(fn_or_cls, *args, **kwargs) -> Any: @@ -152,29 +170,37 @@ def signature_binding(fn_or_cls, *args, **kwargs) -> Any: arguments = signature.bind_partial(*args, **kwargs).arguments for name in list(arguments.keys()): # Make a copy in case we mutate. param = signature.parameters[name] - if param.kind == param.VAR_POSITIONAL: - # TODO(b/197367863): Add *args support. - err_msg = ( - 'Variable positional arguments (aka `*args`) not supported. ' - f'Found param `{name}` in `{fn_or_cls}`.' - ) - raise NotImplementedError(err_msg) - elif param.kind == param.VAR_KEYWORD: + if param.kind == param.VAR_KEYWORD: arguments.update(arguments.pop(param.name)) return arguments - def validate_param_name(self, name, fn_or_cls) -> None: + def get_positional_names(self) -> Tuple[List[str], List[str], str]: + """Get positional argument names.""" + positional_only = [] + keyword_or_positional = [] + for param in self.signature.parameters.values(): + if param.kind == param.POSITIONAL_ONLY: + positional_only.append(param.name) + elif param.kind == param.POSITIONAL_OR_KEYWORD: + keyword_or_positional.append(param.name) + return positional_only, keyword_or_positional, self.var_positional_name + + def validate_param_name( + self, name, fn_or_cls, allow_postional_argument=False + ) -> None: """Raises an error if ``name`` is not a valid parameter name.""" param = self.signature.parameters.get(name) if param is not None: - if param.kind == param.POSITIONAL_ONLY: - # TODO(b/197367863): Add positional-only arg support. - raise NotImplementedError( - 'Positional only arguments not supported. ' - f'Tried to set {name!r} on {fn_or_cls}' + if param.kind == param.POSITIONAL_ONLY and not allow_postional_argument: + raise AttributeError( + f'Cannot access POSITIONAL_ONLY parameter {name!r} on {fn_or_cls}' ) - elif param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD): + elif param.kind == param.VAR_POSITIONAL and not allow_postional_argument: + raise AttributeError( + f'Cannot access VAR_POSITIONAL parameter {name!r} on {fn_or_cls}' + ) + elif param.kind == param.VAR_KEYWORD: # Just pretend it doesn't correspond to a valid parameter name... below # a TypeError will be thrown unless there is a **kwargs parameter. param = None diff --git a/fiddle/config.py b/fiddle/config.py index 1735025b..cc68b606 100644 --- a/fiddle/config.py +++ b/fiddle/config.py @@ -22,5 +22,6 @@ from fiddle._src.config import Config from fiddle._src.config import NO_VALUE from fiddle._src.config import NoValue +from fiddle._src.config import VARARGS from fiddle._src.partial import ArgFactory from fiddle._src.partial import Partial diff --git a/fiddle/examples/colabs/basic_api.ipynb b/fiddle/examples/colabs/basic_api.ipynb index 9e13ad38..f783b8ee 100644 --- a/fiddle/examples/colabs/basic_api.ipynb +++ b/fiddle/examples/colabs/basic_api.ipynb @@ -401,50 +401,7 @@ "id": "G3IVzfktqAIu" }, "source": [ - "but `*args` are currently unsupported," - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "colab": { - "height": 34 - }, - "executionInfo": { - "elapsed": 4, - "status": "ok", - "timestamp": 1692835092749, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "p-r9vED0qNib", - "outputId": "603b56d2-862b-4962-fc5c-10f03547aa75" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\u003cspan style=\"color: red\"\u003eNotImplementedError: Variable positional arguments (aka `*args`) not supported. Found param `args` in `\u003cfunction args_and_kwargs at 0x7f4ba5bd4670\u003e`.\u003c/span\u003e" - ], - "text/plain": [ - "\u003cIPython.core.display.HTML object\u003e" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "try:\n", - " fdl.Config(args_and_kwargs, 4, 7)\n", - "except NotImplementedError as e:\n", - " display(HTML(f'\u003cspan style=\"color: red\"\u003eNotImplementedError: {e}\u003c/span\u003e'))\n", - "else:\n", - " raise AssertionError(\"This should raise an error!\")" + "# TODO(b/288893692): Update docs for posistional args." ] }, {