diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 820b9bb428..09c5cd6aa0 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -24,6 +24,8 @@ Note that PyTorch 1.6.0 or greater should be installed to use this feature; see
 - The minimum supported version of TensorFlow was increased to 1.14.0. (#4411)
 - A CNN (`vis_encode_type: match3`) for smaller grids, e.g. board games, has been
   added. (#4434)
+- You can now again specify a default configuration for your behaviors. Specify `default_settings` in
+  your trainer configuration to do so. (#4448)
 
 ### Bug Fixes
 #### com.unity.ml-agents (C#)
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 5327db5e49..bc699fc380 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -337,6 +337,24 @@ each of these parameters mean and provide guidelines on how to set them. See
 description of all the configurations listed above, along with their defaults. Unless
 otherwise specified, omitting a configuration will revert it to its default.
 
+### Default Behavior Settings
+
+In some cases, you may want to specify a set of default configurations for your Behaviors.
+This may be useful, for instance, if your Behavior names are generated procedurally by
+the environment and not known before runtime, or if you have many Behaviors with very similar
+settings. To specify a default configuration, insert a `default_settings` section in your YAML.
+This section should be formatted exactly like a configuration for a Behavior.
+
+```yaml
+default_settings:
+  # < Same as Behavior configuration >
+behaviors:
+  # < Same as above >
+```
+
+Behaviors found in the environment that aren't specified in the YAML will use `default_settings`,
+and any settings left unspecified in a Behavior's configuration will fall back to the values
+given in `default_settings`.
 
 ### Environment Parameters
 
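For illustration, the documented behavior above can also be exercised directly from Python through `RunOptions.from_dict`, which the settings changes below implement. This is a minimal sketch only: the behavior names `Striker`/`Goalie` and the numeric values are made up for the example.

```python
from mlagents.trainers.settings import RunOptions

# Python equivalent of a YAML config with a default_settings section.
config = {
    "default_settings": {"max_steps": 500000, "network_settings": {"num_layers": 3}},
    "behaviors": {"Striker": {"network_settings": {"num_layers": 1}}},
}
run_options = RunOptions.from_dict(config)

# "Striker" keeps its explicit override but inherits the default max_steps...
assert run_options.behaviors["Striker"].network_settings.num_layers == 1
assert run_options.behaviors["Striker"].max_steps == 500000
# ...and a behavior never mentioned in the config gets the defaults wholesale.
assert run_options.behaviors["Goalie"].network_settings.num_layers == 3
```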
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index 2a1f27b391..8863944ca5 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -2,13 +2,24 @@
 import attr
 import cattr
-from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple, Union
+from typing import (
+    Dict,
+    Optional,
+    List,
+    Any,
+    DefaultDict,
+    Mapping,
+    Tuple,
+    Union,
+    ClassVar,
+)
 from enum import Enum
 import collections
 import argparse
 import abc
 import numpy as np
 import math
+import copy
 from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
 from mlagents.trainers.cli_utils import load_config
@@ -46,6 +57,17 @@ def defaultdict_to_dict(d: DefaultDict) -> Dict:
     return {key: cattr.unstructure(val) for key, val in d.items()}
 
 
+def deep_update_dict(d: Dict, update_d: Mapping) -> None:
+    """
+    Similar to dict.update(), but works for nested dicts of dicts as well.
+    """
+    for key, val in update_d.items():
+        if key in d and isinstance(d[key], Mapping) and isinstance(val, Mapping):
+            deep_update_dict(d[key], val)
+        else:
+            d[key] = val
+
+
 class SerializationSettings:
     convert_to_barracuda = True
     convert_to_onnx = True
@@ -539,6 +561,7 @@ class FrameworkType(Enum):
 
 @attr.s(auto_attribs=True)
 class TrainerSettings(ExportableSettings):
+    default_override: ClassVar[Optional["TrainerSettings"]] = None
     trainer_type: TrainerType = TrainerType.PPO
     hyperparameters: HyperparamSettings = attr.ib()
@@ -578,8 +601,8 @@ def _check_batch_size_seq_length(self, attribute, value):
 
     @staticmethod
     def dict_to_defaultdict(d: Dict, t: type) -> DefaultDict:
-        return collections.defaultdict(
-            TrainerSettings, cattr.structure(d, Dict[str, TrainerSettings])
+        return TrainerSettings.DefaultTrainerDict(
+            cattr.structure(d, Dict[str, TrainerSettings])
         )
 
     @staticmethod
@@ -588,10 +611,18 @@ def structure(d: Mapping, t: type) -> Any:
         Helper method to structure a TrainerSettings class. Meant to be registered with
         cattr.register_structure_hook() and called with cattr.structure().
         """
+
         if not isinstance(d, Mapping):
             raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.")
+
         d_copy: Dict[str, Any] = {}
-        d_copy.update(d)
+
+        # Check if a default_settings was specified. If so, use those as the default
+        # rather than an empty dict.
+        if TrainerSettings.default_override is not None:
+            d_copy.update(cattr.unstructure(TrainerSettings.default_override))
+
+        deep_update_dict(d_copy, d)
 
         for key, val in d_copy.items():
             if attr.has(type(val)):
@@ -613,6 +644,16 @@ def structure(d: Mapping, t: type) -> Any:
             d_copy[key] = check_and_structure(key, val, t)
         return t(**d_copy)
 
+    class DefaultTrainerDict(collections.defaultdict):
+        def __init__(self, *args):
+            super().__init__(TrainerSettings, *args)
+
+        def __missing__(self, key: Any) -> "TrainerSettings":
+            if TrainerSettings.default_override is not None:
+                return copy.deepcopy(TrainerSettings.default_override)
+            else:
+                return TrainerSettings()
+
 
 # COMMAND LINE #########################################################################
 @attr.s(auto_attribs=True)
@@ -653,8 +694,9 @@ class EngineSettings:
 
 @attr.s(auto_attribs=True)
 class RunOptions(ExportableSettings):
+    default_settings: Optional[TrainerSettings] = None
     behaviors: DefaultDict[str, TrainerSettings] = attr.ib(
-        factory=lambda: collections.defaultdict(TrainerSettings)
+        factory=TrainerSettings.DefaultTrainerDict
     )
     env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
     engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
@@ -733,4 +775,12 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
 
     @staticmethod
     def from_dict(options_dict: Dict[str, Any]) -> "RunOptions":
+        # If a default_settings section was specified, set the TrainerSettings class override
+        if (
+            "default_settings" in options_dict.keys()
+            and options_dict["default_settings"] is not None
+        ):
+            TrainerSettings.default_override = cattr.structure(
+                options_dict["default_settings"], TrainerSettings
+            )
         return cattr.structure(options_dict, RunOptions)
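The point of `deep_update_dict`, compared with a plain `dict.update`, is that nested mappings are merged key by key instead of being replaced wholesale; this is what lets a behavior section override a single nested field without discarding the rest of the defaults. A small sketch of that difference (field names taken from the settings above, values illustrative):

```python
from mlagents.trainers.settings import deep_update_dict

# Plain dict.update() replaces the nested section wholesale, dropping num_layers...
shallow = {"max_steps": 1, "network_settings": {"num_layers": 3, "hidden_units": 128}}
shallow.update({"network_settings": {"hidden_units": 256}})
assert shallow["network_settings"] == {"hidden_units": 256}

# ...whereas deep_update_dict() merges nested mappings key by key.
deep = {"max_steps": 1, "network_settings": {"num_layers": 3, "hidden_units": 128}}
deep_update_dict(deep, {"network_settings": {"hidden_units": 256}})
assert deep["network_settings"] == {"num_layers": 3, "hidden_units": 256}
```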
diff --git a/ml-agents/mlagents/trainers/tests/test_settings.py b/ml-agents/mlagents/trainers/tests/test_settings.py
index 2a9393c1b3..d4cfb1d1b9 100644
--- a/ml-agents/mlagents/trainers/tests/test_settings.py
+++ b/ml-agents/mlagents/trainers/tests/test_settings.py
@@ -1,4 +1,5 @@
 import attr
+import cattr
 import pytest
 import yaml
 
@@ -20,6 +21,7 @@
     GaussianSettings,
     MultiRangeUniformSettings,
     TrainerType,
+    deep_update_dict,
     strict_to_cls,
 )
 from mlagents.trainers.exception import TrainerConfigError
@@ -104,6 +106,14 @@ class TestAttrsClass:
         strict_to_cls("non_dict_input", TestAttrsClass)
 
 
+def test_deep_update_dict():
+    dict1 = {"a": 1, "b": 2, "c": {"d": 3}}
+    dict2 = {"a": 2, "c": {"d": 4, "e": 5}}
+
+    deep_update_dict(dict1, dict2)
+    assert dict1 == {"a": 2, "b": 2, "c": {"d": 4, "e": 5}}
+
+
 def test_trainersettings_structure():
     """
     Test structuring method for TrainerSettings
@@ -468,3 +478,25 @@ def test_environment_settings():
     # Multiple environments with no env_path is an error
     with pytest.raises(ValueError):
         EnvironmentSettings(num_envs=2)
+
+
+def test_default_settings():
+    # Make default settings, one nested and one not.
+    default_settings = {"max_steps": 1, "network_settings": {"num_layers": 1000}}
+    behaviors = {"test1": {"max_steps": 2, "network_settings": {"hidden_units": 2000}}}
+    run_options_dict = {"default_settings": default_settings, "behaviors": behaviors}
+    run_options = RunOptions.from_dict(run_options_dict)
+
+    # Check that a new behavior has the default settings
+    default_settings_cls = cattr.structure(default_settings, TrainerSettings)
+    check_if_different(default_settings_cls, run_options.behaviors["test2"])
+
+    # Check that an existing behavior overrides the defaults in specified fields
+    test1_settings = run_options.behaviors["test1"]
+    assert test1_settings.max_steps == 2
+    assert test1_settings.network_settings.hidden_units == 2000
+    assert test1_settings.network_settings.num_layers == 1000
+    # Change the overridden fields back, and check if the rest are equal.
+    test1_settings.max_steps = 1
+    test1_settings.network_settings.hidden_units = default_settings_cls.network_settings.hidden_units
+    check_if_different(test1_settings, default_settings_cls)
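One consequence of the `copy.deepcopy` in `DefaultTrainerDict.__missing__` is that every behavior resolved this way gets its own settings object, so mutating one behavior's settings cannot leak into another's or into the shared `default_override`. A sketch of that property, using hypothetical behavior names and an illustrative `max_steps` value:

```python
from mlagents.trainers.settings import RunOptions

run_options = RunOptions.from_dict(
    {"default_settings": {"max_steps": 1000}, "behaviors": {}}
)

# Neither behavior appears in the config, so both come from the defaults.
settings_a = run_options.behaviors["BehaviorA"]
settings_b = run_options.behaviors["BehaviorB"]

assert settings_a is not settings_b  # independent deep copies of the defaults
settings_a.max_steps = 2000
assert settings_b.max_steps == 1000  # the mutation does not leak across behaviors
```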