diff --git a/torchx/specs/builders.py b/torchx/specs/builders.py index 88e4b85f3..4f7c3af25 100644 --- a/torchx/specs/builders.py +++ b/torchx/specs/builders.py @@ -10,7 +10,7 @@ import inspect import os from argparse import Namespace -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union from torchx.specs.api import BindMount, MountType, VolumeMount from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter @@ -19,6 +19,14 @@ from .api import AppDef, DeviceMount +class ComponentArgs(NamedTuple): + """Parsed component function arguments""" + + positional_args: dict[str, Any] + var_args: list[str] + kwargs: dict[str, Any] + + def _create_args_parser( cmpnt_fn: Callable[..., AppDef], cmpnt_defaults: Optional[Dict[str, str]] = None, @@ -140,6 +148,91 @@ def parse_args( return parsed_args +def component_args_from_str( + cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type + cmpnt_args: list[str], + cmpnt_args_defaults: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, +) -> ComponentArgs: + """ + Parses and decodes command-line arguments for a component function. + + This function takes a component function and its arguments, parses them using argparse, + and decodes the arguments into their expected types based on the function's signature. + It separates positional arguments, variable positional arguments (*args), and keyword-only arguments. + + Args: + cmpnt_fn: The component function whose arguments are to be parsed and decoded. + cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments. + cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters. + config: Optional dictionary containing additional configuration values. + + Returns: + ComponentArgs representing the input args to a component function containing: + - positional_args: Dictionary of positional and positional-or-keyword arguments. + - var_args: List of variable positional arguments (*args). + - kwargs: Dictionary of keyword-only arguments. + + Usage: + + .. doctest:: + from torchx.specs.api import AppDef + from torchx.specs.builders import component_args_from_str + + def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef: + return AppDef(name="example") + + # Supports space separated arguments + args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"] + parsed_args = component_args_from_str(example_component_fn, args) + + assert parsed_args.positional_args == {"foo": "fooval"} + assert parsed_args.var_args == ["arg1", "arg2"] + assert parsed_args.kwargs == {"bar": "barval"} + + # Supports '=' separated arguments + args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"] + parsed_args = component_args_from_str(example_component_fn, args) + + assert parsed_args.positional_args == {"foo": "fooval"} + assert parsed_args.var_args == ["arg1", "arg2"] + assert parsed_args.kwargs == {"bar": "barval"} + + + """ + parsed_args: Namespace = parse_args( + cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config + ) + + positional_args = {} + var_args = [] + kwargs = {} + + parameters = inspect.signature(cmpnt_fn).parameters + for param_name, parameter in parameters.items(): + arg_value = getattr(parsed_args, param_name) + parameter_type = parameter.annotation + parameter_type = decode_optional(parameter_type) + arg_value = decode(arg_value, parameter_type) + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + var_args = arg_value + elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: + kwargs[param_name] = arg_value + elif parameter.kind == inspect.Parameter.VAR_KEYWORD: + raise TypeError( + f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the " + f"type to a dict or explicitly declare the params" + ) + else: + # POSITIONAL or POSITIONAL_OR_KEYWORD + positional_args[param_name] = arg_value + + if len(var_args) > 0 and var_args[0] == "--": + var_args = var_args[1:] + + return ComponentArgs(positional_args, var_args, kwargs) + + def materialize_appdef( cmpnt_fn: Callable[..., Any], # pyre-ignore[2] cmpnt_args: List[str], @@ -174,30 +267,14 @@ def materialize_appdef( An application spec """ - function_args = [] - var_arg = [] - kwargs = {} - - parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config) - - parameters = inspect.signature(cmpnt_fn).parameters - for param_name, parameter in parameters.items(): - arg_value = getattr(parsed_args, param_name) - parameter_type = parameter.annotation - parameter_type = decode_optional(parameter_type) - arg_value = decode(arg_value, parameter_type) - if parameter.kind == inspect.Parameter.VAR_POSITIONAL: - var_arg = arg_value - elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: - kwargs[param_name] = arg_value - elif parameter.kind == inspect.Parameter.VAR_KEYWORD: - raise TypeError("**kwargs are not supported for component definitions") - else: - function_args.append(arg_value) - if len(var_arg) > 0 and var_arg[0] == "--": - var_arg = var_arg[1:] + component_args: ComponentArgs = component_args_from_str( + cmpnt_fn, cmpnt_args, cmpnt_defaults, config + ) + positional_arg_values = list(component_args.positional_args.values()) + appdef = cmpnt_fn( + *positional_arg_values, *component_args.var_args, **component_args.kwargs + ) - appdef = cmpnt_fn(*function_args, *var_arg, **kwargs) if not isinstance(appdef, AppDef): raise TypeError( f"Expected a component that returns `AppDef`, but got `{type(appdef)}`" diff --git a/torchx/specs/test/builders_test.py b/torchx/specs/test/builders_test.py index a733f5502..e3ec31f5d 100644 --- a/torchx/specs/test/builders_test.py +++ b/torchx/specs/test/builders_test.py @@ -18,6 +18,8 @@ from torchx.specs.builders import ( _create_args_parser, BindMount, + component_args_from_str, + ComponentArgs, DeviceMount, make_app_handle, materialize_appdef, @@ -281,6 +283,36 @@ def _get_app_args_and_defaults_with_nested_objects( *role_args, ], defaults + def test_component_args_from_str(self) -> None: + component_fn_args = [ + "--foo", + "fooval", + "--bar", + "barval", + "arg1", + "arg2", + ] + parsed_args: ComponentArgs = component_args_from_str( + example_var_args, component_fn_args + ) + self.assertEqual(parsed_args.positional_args, {"foo": "fooval"}) + self.assertEqual(parsed_args.var_args, ["arg1", "arg2"]) + self.assertEqual(parsed_args.kwargs, {"bar": "barval"}) + + def test_component_args_from_str_equals_separated(self) -> None: + component_fn_args = [ + "--foo=fooval", + "--bar=barval", + "arg1", + "arg2", + ] + parsed_args: ComponentArgs = component_args_from_str( + example_var_args, component_fn_args + ) + self.assertEqual(parsed_args.positional_args, {"foo": "fooval"}) + self.assertEqual(parsed_args.var_args, ["arg1", "arg2"]) + self.assertEqual(parsed_args.kwargs, {"bar": "barval"}) + def test_load_from_fn_empty(self) -> None: actual_app = materialize_appdef(example_empty_fn, []) expected_app = get_dummy_application("trainer")