Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 163 additions & 89 deletions fiddle/_src/absl_flags/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

"""API to use command line flags with Fiddle Buildables."""

import dataclasses
import re
import types
from typing import Any, Optional, TypeVar
from typing import Any, List, Optional, Text, TypeVar, Union

from absl import flags
from etils import epath
Expand Down Expand Up @@ -81,60 +82,32 @@ def serialize(self, value: config.Buildable) -> str:
return f"config_str:{serialized}"


class FiddleFlag(flags.MultiFlag):
"""ABSL flag class for a Fiddle config flag.
@dataclasses.dataclass
class _LazyFlagValue:
"""Represents a lazily evaluated Fiddle flag value.

This class is used to parse command line flags to construct a Fiddle `Config`
object with certain transformations applied as specified in the command line
flags.
This is separate from FiddleFlag because it is used by both defaults and
provided flags.

Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
class directly provides flexibility to users to parse Fiddle flags themselves
programmatically. Also see the documentation for `DEFINE_fiddle_config()`
below.
Lazy flag values are useful because they allow other parts of the system to
be set up, so things like logging can be configured before a configuration is
loaded.
"""

Example usage where this flag is parsed from existing flag:
```
from fiddle import absl_flags as fdl_flags
flag_name: str
remaining_directives: List[str] = dataclasses.field(default_factory=list)
first_command: Optional[str] = None
initial_config_expression: Optional[str] = None

_MY_CONFIG = fdl_flags.DEFINE_multi_string(
"my_config",
"Name of the fiddle config"
)
default_module: Optional[types.ModuleType] = None
allow_imports: bool = True
pyref_policy: Optional[serialization.PyrefPolicy] = None

fiddle_flag = fdl_flags.FiddleFlag(
name="config",
default_module=my_module,
default=None,
parser=flags.ArgumentParser(),
serializer=None,
help_string="My fiddle flag",
)
fiddle_flag.parse(_MY_CONFIG.value)
config = fiddle_flag.value
```
"""

def __init__(
self,
*args,
default_module: Optional[types.ModuleType] = None,
allow_imports: bool = True,
pyref_policy: Optional[serialization.PyrefPolicy] = None,
**kwargs,
):
self.allow_imports = allow_imports
self.default_module = default_module
self._pyref_policy = pyref_policy
self.first_command = None
self._initial_config_expression = None
# A `directive` is a str of the form e.g. 'config:...'.
# Due to the lazy evaluation of `value`, this list is needed to keep
# track of the remaining `directives`.
self._remaining_directives = []
super().__init__(*args, **kwargs)
# Only set internally, please use get_value() / set_value().
_value: Optional[Any] = None

def _initial_config(self, expression: str):
"""Generates the initial config from a config:<expression> directive."""
call_expr = utils.CallExpression.parse(expression)
base_name = call_expr.func_name
base_fn = utils.resolve_function_reference(
Expand All @@ -150,6 +123,7 @@ def _initial_config(self, expression: str):
return base_fn(*call_expr.args, **call_expr.kwargs)

def _apply_fiddler(self, cfg: config.Buildable, expression: str):
"""Modifies the config from the given CLI flag."""
call_expr = utils.CallExpression.parse(expression)
base_name = call_expr.func_name
fiddler = utils.resolve_function_reference(
Expand All @@ -175,57 +149,40 @@ def _apply_fiddler(self, cfg: config.Buildable, expression: str):
# `fdl.Buildable` object.
return new_cfg if new_cfg is not None else cfg

def parse(self, arguments):
new_parsed = self._parse(arguments)
self._remaining_directives.extend(new_parsed)
self.present += len(new_parsed)

def unparse(self) -> None:
self.value = self.default
self.using_default_value = True
# Reset it so that all `directives` not being processed yet will be
# discarded.
self._remaining_directives = []
self.present = 0

def _parse_config(self, command: str, expression: str) -> None:
  """Sets the initial config from the given CLI flag/directive.

  Args:
    command: Directive kind. The branches below handle "config",
      "config_file" and "config_str"; any other value records the expression
      but leaves the stored value untouched.
    expression: Text following the command: a call expression for "config",
      a path for "config_file", or a serialized payload for "config_str".

  Raises:
    ValueError: If a base configuration expression was already recorded.
  """
  if self.initial_config_expression:
    # Adjacent string literals are concatenated verbatim, so every fragment
    # has to carry its own separating space ("Received " in particular).
    raise ValueError(
        "Only one base configuration is permitted. Received "
        f"{command}:{expression} after "
        f"{self.first_command}:{self.initial_config_expression} was"
        " already provided."
    )

  self.initial_config_expression = expression
  if command == "config":
    self._value = self._initial_config(expression)
  elif command == "config_file":
    # The file content is handed to the JSON loader together with the
    # configured pyref policy.
    with epath.Path(expression).open() as f:
      self._value = serialization.load_json(
          f.read(), pyref_policy=self.pyref_policy
      )
  elif command == "config_str":
    serializer = utils.ZlibJSONSerializer()
    self._value = serializer.deserialize(
        expression, pyref_policy=self.pyref_policy
    )

def _serialize(self, value) -> str:
# Skip MultiFlag serialization as we don't truly have a multi-flag.
# This will invoke Flag._serialize
return super(flags.MultiFlag, self)._serialize(value)

@property
def value(self):
while self._remaining_directives:
def get_value(self):
"""Gets the current value (parsing any directives)."""
while self.remaining_directives:
# Pop already processed `directive` so that _value won't be updated twice
# by the same argument.
item = self._remaining_directives.pop(0)
item = self.remaining_directives.pop(0)
match = _COMMAND_RE.fullmatch(item)
if not match:
raise ValueError(
f"All flag values to {self.name} must begin with 'config:', "
f"All flag values to {self.flag_name} must begin with 'config:', "
"'config_file:', 'config_str:', 'set:', or 'fiddler:'."
)
command, expression = match.groups()
Expand All @@ -235,7 +192,9 @@ def value(self):
raise ValueError(
"First flag command must specify the input config via either "
"config or config_file or config_str commands. "
f"Received command: {command} instead."
f"Received command: {command} instead. If you have a default "
"value set, you must re-provide that on the CLI before setting "
"values or running fiddlers."
)
self.first_command = command

Expand All @@ -254,15 +213,130 @@ def value(self):
raise AssertionError("Internal error; should not be reached.")
return self._value

def set_value(self, value: Any) -> None:
  """Overwrites the stored value directly.

  Only `_value` is assigned here; the list of pending directives is not
  modified by this method.

  Args:
    value: Object to store as the current value.
  """
  self._value = value


class FiddleFlag(flags.MultiFlag):
  """ABSL flag class for a Fiddle config flag.

  This class is used to parse command line flags to construct a Fiddle `Config`
  object with certain transformations applied as specified in the command line
  flags.

  Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
  class directly provides flexibility to users to parse Fiddle flags themselves
  programmatically. Also see the documentation for `DEFINE_fiddle_config()`
  below.

  Example usage where this flag is parsed from existing flag:
  ```
  from fiddle import absl_flags as fdl_flags

  _MY_CONFIG = fdl_flags.DEFINE_multi_string(
      "my_config",
      "Name of the fiddle config"
  )

  fiddle_flag = fdl_flags.FiddleFlag(
      name="config",
      default_module=my_module,
      default=None,
      parser=flags.ArgumentParser(),
      serializer=None,
      help_string="My fiddle flag",
  )
  fiddle_flag.parse(_MY_CONFIG.value)
  config = fiddle_flag.value
  ```
  """

  def __init__(
      self,
      *args,
      name: Text,
      default_module: Optional[types.ModuleType] = None,
      allow_imports: bool = True,
      pyref_policy: Optional[serialization.PyrefPolicy] = None,
      **kwargs,
  ):
    """Initializes the flag and its lazy value/default holders.

    Args:
      *args: Positional arguments forwarded to the base flag class.
      name: Name of the flag; also recorded on the lazy holders.
      default_module: Module forwarded to the lazy holders.
      allow_imports: Forwarded to the lazy holders.
      pyref_policy: Policy forwarded to the lazy holders.
      **kwargs: Keyword arguments forwarded to the base flag class.
    """
    self.allow_imports = allow_imports
    self.default_module = default_module
    self._pyref_policy = pyref_policy
    # Both holders are created before the base initializer runs because the
    # `value`/`default` property setters below dereference them.
    # NOTE(review): presumably the base initializer assigns `default` and
    # `value` (see the note in the `default` setter) -- keep this ordering.
    self._lazy_default = self._make_lazy_value(name)
    self._lazy_value = self._make_lazy_value(name)
    kwargs["name"] = name
    super().__init__(*args, **kwargs)

  def _make_lazy_value(self, flag_name: Text) -> "_LazyFlagValue":
    """Builds an empty lazy holder that shares this flag's settings."""
    return _LazyFlagValue(
        flag_name=flag_name,
        default_module=self.default_module,
        allow_imports=self.allow_imports,
        pyref_policy=self._pyref_policy,
    )

  def parse(self, arguments):
    """Queues the parsed directives; they are evaluated when `value` is read."""
    new_parsed = self._parse(arguments)
    self._lazy_value.remaining_directives.extend(new_parsed)
    self.present += len(new_parsed)

  def _parse_from_default(
      self, value: Union[Text, List[Any]]
  ) -> Optional[List[Any]]:
    """Wraps the parsed default directives in a fresh lazy holder.

    The returned object is a `_LazyFlagValue`, not a list; the `default`
    setter recognizes that type and installs it as the lazy default.

    Args:
      value: Default flag value, passed through `self._parse()`.

    Returns:
      A `_LazyFlagValue` whose pending directives are the parsed defaults.
    """
    lazy_default_value = self._make_lazy_value(self.name)
    value = self._parse(value)
    assert isinstance(value, list)
    lazy_default_value.remaining_directives.extend(value)
    return lazy_default_value  # pytype: disable=bad-return-type

  def unparse(self) -> None:
    """Resets the flag to its default and drops any pending directives."""
    self.value = self.default
    self.using_default_value = True
    # Reset it so that all `directives` not being processed yet will be
    # discarded.
    self._lazy_value.remaining_directives = []
    self.present = 0

  def _serialize(self, value) -> str:
    # Skip MultiFlag serialization as we don't truly have a multi-flag.
    # This will invoke Flag._serialize
    return super(flags.MultiFlag, self)._serialize(value)

  @property
  def value(self):
    """Current value; reading it delegates to the lazy holder's get_value()."""
    return self._lazy_value.get_value()

  @value.setter
  def value(self, value):
    # Stored both on this object and on the lazy holder.
    self._value = value
    self._lazy_value.set_value(value)

  @property
  def default(self):
    """Default value; reading it delegates to the lazy default's get_value()."""
    return self._lazy_default.get_value()

  @default.setter
  def default(self, value):
    if isinstance(value, _LazyFlagValue):
      # Note: This is only for _set_default(). We might choose to override that
      # instead of just _parse_from_default(), in which case this branch can be
      # removed.
      self._lazy_default = value
    else:
      self._lazy_default.set_value(value)


def DEFINE_fiddle_config( # pylint: disable=invalid-name
name: str,
*,
default: Any = None,
default_flag_str: Optional[str] = None,
help_string: str,
default_module: Optional[types.ModuleType] = None,
pyref_policy: Optional[serialization.PyrefPolicy] = None,
Expand Down Expand Up @@ -317,12 +391,12 @@ def main(argv) -> None:
python3 -m path.to.my.binary --my_config=config_file:path/to/file

Args:
name: name of the command line flag.
default: default value of the flag.
help_string: help string describing what the flag does.
default_module: the python module where this flag is defined.
pyref_policy: a policy for importing references to Python objects.
flag_values: the ``FlagValues`` instance with which the flag will be
name: Name of the command line flag.
default_flag_str: Default value of the flag.
help_string: Help string describing what the flag does.
default_module: The python module where this flag is defined.
pyref_policy: A policy for importing references to Python objects.
flag_values: The ``FlagValues`` instance with which the flag will be
registered. This should almost never need to be overridden.
required: bool, is this a required flag. This must be used as a keyword
argument.
Expand All @@ -334,7 +408,7 @@ def main(argv) -> None:
FiddleFlag(
name=name,
default_module=default_module,
default=default,
default=default_flag_str,
pyref_policy=pyref_policy,
parser=flags.ArgumentParser(),
serializer=FiddleFlagSerializer(pyref_policy=pyref_policy),
Expand Down