Skip to content

Commit 383de98

Browse files
author
Ervin T
authored
[refactor] Allow full RunOptions to be specified in trainer configuration YAML (#3815)
1 parent 7439038 commit 383de98

File tree

4 files changed

+162
-25
lines changed

4 files changed

+162
-25
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ you will need to change the signature of its `Write()` method. (#3834)
103103
will allow use with python 3.8 using tensorflow 2.2.0rc3.
104104
- `UnityRLCapabilities` was added to help inform users when RL features are mismatched between C# and Python packages. (#3831)
105105
- Unity Player logs are now written out to the results directory. (#3877)
106+
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
106107

107108
### Bug Fixes
108109

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Set
2+
import argparse
3+
4+
5+
class DetectDefault(argparse.Action):
6+
"""
7+
Internal custom Action to help detect arguments that aren't default.
8+
"""
9+
10+
non_default_args: Set[str] = set()
11+
12+
def __call__(self, arg_parser, namespace, values, option_string=None):
13+
setattr(namespace, self.dest, values)
14+
DetectDefault.non_default_args.add(self.dest)
15+
16+
17+
class DetectDefaultStoreTrue(DetectDefault):
18+
"""
19+
Internal class to help detect arguments that aren't default.
20+
Used for store_true arguments.
21+
"""
22+
23+
def __init__(self, nargs=0, **kwargs):
24+
super().__init__(nargs=nargs, **kwargs)
25+
26+
def __call__(self, arg_parser, namespace, values, option_string=None):
27+
super().__call__(arg_parser, namespace, True, option_string)
28+
29+
30+
class StoreConfigFile(argparse.Action):
31+
"""
32+
Custom Action to store the config file location not as part of the CLI args.
33+
This is because we want to maintain an equivalence between the config file's
34+
contents and the args themselves.
35+
"""
36+
37+
trainer_config_path: str
38+
39+
def __call__(self, arg_parser, namespace, values, option_string=None):
40+
delattr(namespace, self.dest)
41+
StoreConfigFile.trainer_config_path = values

ml-agents/mlagents/trainers/learn.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
GaugeWriter,
2727
ConsoleWriter,
2828
)
29+
from mlagents.trainers.cli_utils import (
30+
StoreConfigFile,
31+
DetectDefault,
32+
DetectDefaultStoreTrue,
33+
)
2934
from mlagents_envs.environment import UnityEnvironment
3035
from mlagents.trainers.sampler_class import SamplerManager
3136
from mlagents.trainers.exception import SamplerException, TrainerConfigError
@@ -48,18 +53,20 @@ def _create_parser():
4853
argparser = argparse.ArgumentParser(
4954
formatter_class=argparse.ArgumentDefaultsHelpFormatter
5055
)
51-
argparser.add_argument("trainer_config_path")
56+
argparser.add_argument("trainer_config_path", action=StoreConfigFile)
5257
argparser.add_argument(
5358
"--env",
5459
default=None,
5560
dest="env_path",
5661
help="Path to the Unity executable to train",
62+
action=DetectDefault,
5763
)
5864
argparser.add_argument(
5965
"--lesson",
6066
default=0,
6167
type=int,
6268
help="The lesson to start with when performing curriculum training",
69+
action=DetectDefault,
6370
)
6471
argparser.add_argument(
6572
"--keep-checkpoints",
@@ -68,19 +75,20 @@ def _create_parser():
6875
help="The maximum number of model checkpoints to keep. Checkpoints are saved after the"
6976
"number of steps specified by the save-freq option. Once the maximum number of checkpoints"
7077
"has been reached, the oldest checkpoint is deleted when saving a new checkpoint.",
78+
action=DetectDefault,
7179
)
7280
argparser.add_argument(
7381
"--load",
7482
default=False,
7583
dest="load_model",
76-
action="store_true",
84+
action=DetectDefaultStoreTrue,
7785
help=argparse.SUPPRESS, # Deprecated but still usable for now.
7886
)
7987
argparser.add_argument(
8088
"--resume",
8189
default=False,
8290
dest="resume",
83-
action="store_true",
91+
action=DetectDefaultStoreTrue,
8492
help="Whether to resume training from a checkpoint. Specify a --run-id to use this option. "
8593
"If set, the training code loads an already trained model to initialize the neural network "
8694
"before resuming training. This option is only valid when the models exist, and have the same "
@@ -90,7 +98,7 @@ def _create_parser():
9098
"--force",
9199
default=False,
92100
dest="force",
93-
action="store_true",
101+
action=DetectDefaultStoreTrue,
94102
help="Whether to force-overwrite this run-id's existing summary and model data. (Without "
95103
"this flag, attempting to train a model with a run-id that has been used before will throw "
96104
"an error.",
@@ -103,6 +111,7 @@ def _create_parser():
103111
"as the saved model itself. If you use TensorBoard to view the training statistics, "
104112
"always set a unique run-id for each training run. (The statistics for all runs with the "
105113
"same id are combined as if they were produced by a the same session.)",
114+
action=DetectDefault,
106115
)
107116
argparser.add_argument(
108117
"--initialize-from",
@@ -112,31 +121,34 @@ def _create_parser():
112121
"This can be used, for instance, to fine-tune an existing model on a new environment. "
113122
"Note that the previously saved models must have the same behavior parameters as your "
114123
"current environment.",
124+
action=DetectDefault,
115125
)
116126
argparser.add_argument(
117127
"--save-freq",
118128
default=50000,
119129
type=int,
120130
help="How often (in steps) to save the model during training",
131+
action=DetectDefault,
121132
)
122133
argparser.add_argument(
123134
"--seed",
124135
default=-1,
125136
type=int,
126137
help="A number to use as a seed for the random number generator used by the training code",
138+
action=DetectDefault,
127139
)
128140
argparser.add_argument(
129141
"--train",
130142
default=False,
131143
dest="train_model",
132-
action="store_true",
144+
action=DetectDefaultStoreTrue,
133145
help=argparse.SUPPRESS,
134146
)
135147
argparser.add_argument(
136148
"--inference",
137149
default=False,
138150
dest="inference",
139-
action="store_true",
151+
action=DetectDefaultStoreTrue,
140152
help="Whether to run in Python inference mode (i.e. no training). Use with --resume to load "
141153
"a model trained with an existing run ID.",
142154
)
@@ -149,25 +161,27 @@ def _create_parser():
149161
"will use the port (base_port + worker_id), where the worker_id is sequential IDs given to "
150162
"each instance from 0 to (num_envs - 1). Note that when training using the Editor rather "
151163
"than an executable, the base port will be ignored.",
164+
action=DetectDefault,
152165
)
153166
argparser.add_argument(
154167
"--num-envs",
155168
default=1,
156169
type=int,
157170
help="The number of concurrent Unity environment instances to collect experiences "
158171
"from when training",
172+
action=DetectDefault,
159173
)
160174
argparser.add_argument(
161175
"--no-graphics",
162176
default=False,
163-
action="store_true",
177+
action=DetectDefaultStoreTrue,
164178
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
165179
"the graphics driver. Use this only if your agents don't use visual observations.",
166180
)
167181
argparser.add_argument(
168182
"--debug",
169183
default=False,
170-
action="store_true",
184+
action=DetectDefaultStoreTrue,
171185
help="Whether to enable debug-level logging for some parts of the code",
172186
)
173187
argparser.add_argument(
@@ -178,11 +192,12 @@ def _create_parser():
178192
"process these as Unity Command Line Arguments. You should choose different argument names if "
179193
"you want to create environment-specific arguments. All arguments after this flag will be "
180194
"passed to the executable.",
195+
action=DetectDefault,
181196
)
182197
argparser.add_argument(
183198
"--cpu",
184199
default=False,
185-
action="store_true",
200+
action=DetectDefaultStoreTrue,
186201
help="Forces training using CPU only",
187202
)
188203

@@ -195,41 +210,47 @@ def _create_parser():
195210
type=int,
196211
help="The width of the executable window of the environment(s) in pixels "
197212
"(ignored for editor training).",
213+
action=DetectDefault,
198214
)
199215
eng_conf.add_argument(
200216
"--height",
201217
default=84,
202218
type=int,
203219
help="The height of the executable window of the environment(s) in pixels "
204220
"(ignored for editor training)",
221+
action=DetectDefault,
205222
)
206223
eng_conf.add_argument(
207224
"--quality-level",
208225
default=5,
209226
type=int,
210227
help="The quality level of the environment(s). Equivalent to calling "
211228
"QualitySettings.SetQualityLevel in Unity.",
229+
action=DetectDefault,
212230
)
213231
eng_conf.add_argument(
214232
"--time-scale",
215233
default=20,
216234
type=float,
217235
help="The time scale of the Unity environment(s). Equivalent to setting "
218236
"Time.timeScale in Unity.",
237+
action=DetectDefault,
219238
)
220239
eng_conf.add_argument(
221240
"--target-frame-rate",
222241
default=-1,
223242
type=int,
224243
help="The target frame rate of the Unity environment(s). Equivalent to setting "
225244
"Application.targetFrameRate in Unity.",
245+
action=DetectDefault,
226246
)
227247
eng_conf.add_argument(
228248
"--capture-frame-rate",
229249
default=60,
230250
type=int,
231251
help="The capture frame rate of the Unity environment(s). Equivalent to setting "
232252
"Time.captureFramerate in Unity.",
253+
action=DetectDefault,
233254
)
234255
return argparser
235256

@@ -277,26 +298,35 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
277298
configs loaded from files.
278299
"""
279300
argparse_args = vars(args)
280-
config_path = argparse_args["trainer_config_path"]
281-
# Load YAML and apply overrides as needed
301+
run_options_dict = {}
302+
run_options_dict.update(argparse_args)
303+
config_path = StoreConfigFile.trainer_config_path
304+
305+
# Load YAML
282306
yaml_config = load_config(config_path)
283-
try:
284-
argparse_args["behaviors"] = yaml_config["behaviors"]
285-
except KeyError:
307+
# This is the only option that is not optional and has no defaults.
308+
if "behaviors" not in yaml_config:
286309
raise TrainerConfigError(
287310
"Trainer configurations not found. Make sure your YAML file has a section for behaviors."
288311
)
312+
# Use the YAML file values for all values not specified in the CLI.
313+
for key, val in yaml_config.items():
314+
# Detect bad config options
315+
if not hasattr(RunOptions, key):
316+
raise TrainerConfigError(
317+
"The option {} was specified in your YAML file, but is invalid.".format(
318+
key
319+
)
320+
)
321+
if key not in DetectDefault.non_default_args:
322+
run_options_dict[key] = val
289323

290-
argparse_args["parameter_randomization"] = yaml_config.get(
291-
"parameter_randomization", None
292-
)
293324
# Keep deprecated --load working, TODO: remove
294-
argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"]
295-
# Since argparse accepts file paths in the config options which don't exist in CommandLineOptions,
296-
# these keys will need to be deleted to use the **/splat operator below.
297-
argparse_args.pop("trainer_config_path")
325+
run_options_dict["resume"] = (
326+
run_options_dict["resume"] or run_options_dict["load_model"]
327+
)
298328

299-
return RunOptions(**vars(args))
329+
return RunOptions(**run_options_dict)
300330

301331

302332
def get_version_string() -> str:

0 commit comments

Comments
 (0)