|
10 | 10 | import inspect
|
11 | 11 | import os
|
12 | 12 | from argparse import Namespace
|
13 |
| -from typing import Any, Callable, Dict, List, Mapping, Optional, Union |
| 13 | +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union |
14 | 14 |
|
15 | 15 | from torchx.specs.api import BindMount, MountType, VolumeMount
|
16 | 16 | from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
|
@@ -140,6 +140,62 @@ def parse_args(
|
140 | 140 | return parsed_args
|
141 | 141 |
|
142 | 142 |
|
| 143 | +def parse_and_decode_args( |
| 144 | + cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type |
| 145 | + cmpnt_args: list[str], |
| 146 | + cmpnt_args_defaults: Optional[Dict[str, Any]] = None, |
| 147 | + config: Optional[Dict[str, Any]] = None, |
| 148 | +) -> Tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: |
| 149 | + """ |
| 150 | + Parses and decodes command-line arguments for a component function. |
| 151 | +
|
| 152 | + This function takes a component function and its arguments, parses them using argparse, |
| 153 | + and decodes the arguments into their expected types based on the function's signature. |
| 154 | + It separates positional arguments, variable positional arguments (*args), and keyword-only arguments. |
| 155 | +
|
| 156 | + Args: |
| 157 | + cmpnt_fn: The component function whose arguments are to be parsed and decoded. |
| 158 | + cmpnt_args: List of command-line arguments to be parsed. |
| 159 | + cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters. |
| 160 | + config: Optional dictionary containing additional configuration values. |
| 161 | +
|
| 162 | + Returns: |
| 163 | + A tuple containing: |
| 164 | + - positional_args: Dictionary of positional and positional-or-keyword arguments. |
| 165 | + - var_args: List of variable positional arguments (*args). |
| 166 | + - kwargs: Dictionary of keyword-only arguments. |
| 167 | + """ |
| 168 | + parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config) |
| 169 | + |
| 170 | + positional_args = {} |
| 171 | + var_args = [] |
| 172 | + kwargs = {} |
| 173 | + |
| 174 | + parameters = inspect.signature(cmpnt_fn).parameters |
| 175 | + for param_name, parameter in parameters.items(): |
| 176 | + arg_value = getattr(parsed_args, param_name) |
| 177 | + parameter_type = parameter.annotation |
| 178 | + parameter_type = decode_optional(parameter_type) |
| 179 | + arg_value = decode(arg_value, parameter_type) |
| 180 | + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: |
| 181 | + var_args = arg_value |
| 182 | + elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: |
| 183 | + kwargs[param_name] = arg_value |
| 184 | + elif parameter.kind == inspect.Parameter.VAR_KEYWORD: |
| 185 | + raise TypeError( |
| 186 | + f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the " |
| 187 | + f"type to a dict or explicitly declare the params" |
| 188 | + ) |
| 189 | + else: |
| 190 | + # POSITIONAL or POSITIONAL_OR_KEYWORD |
| 191 | + positional_args[param_name] = arg_value |
| 192 | + |
| 193 | + if len(var_args) > 0 and var_args[0] == "--": |
| 194 | + var_args = var_args[1:] |
| 195 | + |
| 196 | + return positional_args, var_args, kwargs |
| 197 | + |
| 198 | + |
143 | 199 | def materialize_appdef(
|
144 | 200 | cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
|
145 | 201 | cmpnt_args: List[str],
|
@@ -174,30 +230,12 @@ def materialize_appdef(
|
174 | 230 | An application spec
|
175 | 231 | """
|
176 | 232 |
|
177 |
| - function_args = [] |
178 |
| - var_arg = [] |
179 |
| - kwargs = {} |
180 |
| - |
181 |
| - parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config) |
182 |
| - |
183 |
| - parameters = inspect.signature(cmpnt_fn).parameters |
184 |
| - for param_name, parameter in parameters.items(): |
185 |
| - arg_value = getattr(parsed_args, param_name) |
186 |
| - parameter_type = parameter.annotation |
187 |
| - parameter_type = decode_optional(parameter_type) |
188 |
| - arg_value = decode(arg_value, parameter_type) |
189 |
| - if parameter.kind == inspect.Parameter.VAR_POSITIONAL: |
190 |
| - var_arg = arg_value |
191 |
| - elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: |
192 |
| - kwargs[param_name] = arg_value |
193 |
| - elif parameter.kind == inspect.Parameter.VAR_KEYWORD: |
194 |
| - raise TypeError("**kwargs are not supported for component definitions") |
195 |
| - else: |
196 |
| - function_args.append(arg_value) |
197 |
| - if len(var_arg) > 0 and var_arg[0] == "--": |
198 |
| - var_arg = var_arg[1:] |
| 233 | + positional_args, args, kwargs = parse_and_decode_args( |
| 234 | + cmpnt_fn, cmpnt_args, cmpnt_defaults, config |
| 235 | + ) |
| 236 | + positional_arg_values = list(positional_args.values()) |
| 237 | + appdef = cmpnt_fn(*positional_arg_values, *args, **kwargs) |
199 | 238 |
|
200 |
| - appdef = cmpnt_fn(*function_args, *var_arg, **kwargs) |
201 | 239 | if not isinstance(appdef, AppDef):
|
202 | 240 | raise TypeError(
|
203 | 241 | f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
|
|
0 commit comments