Skip to content

Commit 1d0e9f6

Browse files
ethanbwaitefacebook-github-bot
authored andcommitted
(torchx/specs) Add intermediate helper method to parse component arguments (#1097)
Summary: Extracts logic from `materialize_appdef` to new helper method to standardize a way to parse component function arguments Differential Revision: D80215193
1 parent 0255f71 commit 1d0e9f6

File tree

1 file changed

+62
-24
lines changed

1 file changed

+62
-24
lines changed

torchx/specs/builders.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import inspect
1111
import os
1212
from argparse import Namespace
13-
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
13+
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
1414

1515
from torchx.specs.api import BindMount, MountType, VolumeMount
1616
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
@@ -140,6 +140,62 @@ def parse_args(
140140
return parsed_args
141141

142142

143+
def parse_and_decode_args(
144+
cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type
145+
cmpnt_args: list[str],
146+
cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
147+
config: Optional[Dict[str, Any]] = None,
148+
) -> Tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
149+
"""
150+
Parses and decodes command-line arguments for a component function.
151+
152+
This function takes a component function and its arguments, parses them using argparse,
153+
and decodes the arguments into their expected types based on the function's signature.
154+
It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.
155+
156+
Args:
157+
cmpnt_fn: The component function whose arguments are to be parsed and decoded.
158+
cmpnt_args: List of command-line arguments to be parsed.
159+
cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
160+
config: Optional dictionary containing additional configuration values.
161+
162+
Returns:
163+
A tuple containing:
164+
- positional_args: Dictionary of positional and positional-or-keyword arguments.
165+
- var_args: List of variable positional arguments (*args).
166+
- kwargs: Dictionary of keyword-only arguments.
167+
"""
168+
parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config)
169+
170+
positional_args = {}
171+
var_args = []
172+
kwargs = {}
173+
174+
parameters = inspect.signature(cmpnt_fn).parameters
175+
for param_name, parameter in parameters.items():
176+
arg_value = getattr(parsed_args, param_name)
177+
parameter_type = parameter.annotation
178+
parameter_type = decode_optional(parameter_type)
179+
arg_value = decode(arg_value, parameter_type)
180+
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
181+
var_args = arg_value
182+
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
183+
kwargs[param_name] = arg_value
184+
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
185+
raise TypeError(
186+
f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
187+
f"type to a dict or explicitly declare the params"
188+
)
189+
else:
190+
# POSITIONAL or POSITIONAL_OR_KEYWORD
191+
positional_args[param_name] = arg_value
192+
193+
if len(var_args) > 0 and var_args[0] == "--":
194+
var_args = var_args[1:]
195+
196+
return positional_args, var_args, kwargs
197+
198+
143199
def materialize_appdef(
144200
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
145201
cmpnt_args: List[str],
@@ -174,30 +230,12 @@ def materialize_appdef(
174230
An application spec
175231
"""
176232

177-
function_args = []
178-
var_arg = []
179-
kwargs = {}
180-
181-
parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)
182-
183-
parameters = inspect.signature(cmpnt_fn).parameters
184-
for param_name, parameter in parameters.items():
185-
arg_value = getattr(parsed_args, param_name)
186-
parameter_type = parameter.annotation
187-
parameter_type = decode_optional(parameter_type)
188-
arg_value = decode(arg_value, parameter_type)
189-
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
190-
var_arg = arg_value
191-
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
192-
kwargs[param_name] = arg_value
193-
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
194-
raise TypeError("**kwargs are not supported for component definitions")
195-
else:
196-
function_args.append(arg_value)
197-
if len(var_arg) > 0 and var_arg[0] == "--":
198-
var_arg = var_arg[1:]
233+
positional_args, args, kwargs = parse_and_decode_args(
234+
cmpnt_fn, cmpnt_args, cmpnt_defaults, config
235+
)
236+
positional_arg_values = list(positional_args.values())
237+
appdef = cmpnt_fn(*positional_arg_values, *args, **kwargs)
199238

200-
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
201239
if not isinstance(appdef, AppDef):
202240
raise TypeError(
203241
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"

0 commit comments

Comments
 (0)