Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 101 additions & 24 deletions torchx/specs/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import inspect
import os
from argparse import Namespace
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union

from torchx.specs.api import BindMount, MountType, VolumeMount
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
Expand All @@ -19,6 +19,14 @@
from .api import AppDef, DeviceMount


class ComponentArgs(NamedTuple):
"""Parsed component function arguments"""

positional_args: dict[str, Any]
var_args: list[str]
kwargs: dict[str, Any]


def _create_args_parser(
cmpnt_fn: Callable[..., AppDef],
cmpnt_defaults: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -140,6 +148,91 @@ def parse_args(
return parsed_args


def component_args_from_str(
cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type
cmpnt_args: list[str],
cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
) -> ComponentArgs:
"""
Parses and decodes command-line arguments for a component function.

This function takes a component function and its arguments, parses them using argparse,
and decodes the arguments into their expected types based on the function's signature.
It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.

Args:
cmpnt_fn: The component function whose arguments are to be parsed and decoded.
cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments.
cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
config: Optional dictionary containing additional configuration values.

Returns:
ComponentArgs representing the input args to a component function containing:
- positional_args: Dictionary of positional and positional-or-keyword arguments.
- var_args: List of variable positional arguments (*args).
- kwargs: Dictionary of keyword-only arguments.

Usage:

.. doctest::
from torchx.specs.api import AppDef
from torchx.specs.builders import component_args_from_str

def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
return AppDef(name="example")

# Supports space separated arguments
args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
parsed_args = component_args_from_str(example_component_fn, args)

assert parsed_args.positional_args == {"foo": "fooval"}
assert parsed_args.var_args == ["arg1", "arg2"]
assert parsed_args.kwargs == {"bar": "barval"}

# Supports '=' separated arguments
args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
parsed_args = component_args_from_str(example_component_fn, args)

assert parsed_args.positional_args == {"foo": "fooval"}
assert parsed_args.var_args == ["arg1", "arg2"]
assert parsed_args.kwargs == {"bar": "barval"}


"""
parsed_args: Namespace = parse_args(
cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
)

positional_args = {}
var_args = []
kwargs = {}

parameters = inspect.signature(cmpnt_fn).parameters
for param_name, parameter in parameters.items():
arg_value = getattr(parsed_args, param_name)
parameter_type = parameter.annotation
parameter_type = decode_optional(parameter_type)
arg_value = decode(arg_value, parameter_type)
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
var_args = arg_value
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[param_name] = arg_value
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
raise TypeError(
f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
f"type to a dict or explicitly declare the params"
)
else:
# POSITIONAL or POSITIONAL_OR_KEYWORD
positional_args[param_name] = arg_value

if len(var_args) > 0 and var_args[0] == "--":
var_args = var_args[1:]

return ComponentArgs(positional_args, var_args, kwargs)


def materialize_appdef(
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
cmpnt_args: List[str],
Expand Down Expand Up @@ -174,30 +267,14 @@ def materialize_appdef(
An application spec
"""

function_args = []
var_arg = []
kwargs = {}

parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)

parameters = inspect.signature(cmpnt_fn).parameters
for param_name, parameter in parameters.items():
arg_value = getattr(parsed_args, param_name)
parameter_type = parameter.annotation
parameter_type = decode_optional(parameter_type)
arg_value = decode(arg_value, parameter_type)
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg = arg_value
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[param_name] = arg_value
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
raise TypeError("**kwargs are not supported for component definitions")
else:
function_args.append(arg_value)
if len(var_arg) > 0 and var_arg[0] == "--":
var_arg = var_arg[1:]
component_args: ComponentArgs = component_args_from_str(
cmpnt_fn, cmpnt_args, cmpnt_defaults, config
)
positional_arg_values = list(component_args.positional_args.values())
appdef = cmpnt_fn(
*positional_arg_values, *component_args.var_args, **component_args.kwargs
)

appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
if not isinstance(appdef, AppDef):
raise TypeError(
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
Expand Down
32 changes: 32 additions & 0 deletions torchx/specs/test/builders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from torchx.specs.builders import (
_create_args_parser,
BindMount,
component_args_from_str,
ComponentArgs,
DeviceMount,
make_app_handle,
materialize_appdef,
Expand Down Expand Up @@ -281,6 +283,36 @@ def _get_app_args_and_defaults_with_nested_objects(
*role_args,
], defaults

def test_component_args_from_str(self) -> None:
component_fn_args = [
"--foo",
"fooval",
"--bar",
"barval",
"arg1",
"arg2",
]
parsed_args: ComponentArgs = component_args_from_str(
example_var_args, component_fn_args
)
self.assertEqual(parsed_args.positional_args, {"foo": "fooval"})
self.assertEqual(parsed_args.var_args, ["arg1", "arg2"])
self.assertEqual(parsed_args.kwargs, {"bar": "barval"})

def test_component_args_from_str_equals_separated(self) -> None:
component_fn_args = [
"--foo=fooval",
"--bar=barval",
"arg1",
"arg2",
]
parsed_args: ComponentArgs = component_args_from_str(
example_var_args, component_fn_args
)
self.assertEqual(parsed_args.positional_args, {"foo": "fooval"})
self.assertEqual(parsed_args.var_args, ["arg1", "arg2"])
self.assertEqual(parsed_args.kwargs, {"bar": "barval"})

def test_load_from_fn_empty(self) -> None:
actual_app = materialize_appdef(example_empty_fn, [])
expected_app = get_dummy_application("trainer")
Expand Down
Loading