26
26
GaugeWriter ,
27
27
ConsoleWriter ,
28
28
)
29
+ from mlagents .trainers .cli_utils import (
30
+ StoreConfigFile ,
31
+ DetectDefault ,
32
+ DetectDefaultStoreTrue ,
33
+ )
29
34
from mlagents_envs .environment import UnityEnvironment
30
35
from mlagents .trainers .sampler_class import SamplerManager
31
36
from mlagents .trainers .exception import SamplerException , TrainerConfigError
@@ -48,18 +53,20 @@ def _create_parser():
48
53
argparser = argparse .ArgumentParser (
49
54
formatter_class = argparse .ArgumentDefaultsHelpFormatter
50
55
)
51
- argparser .add_argument ("trainer_config_path" )
56
+ argparser .add_argument ("trainer_config_path" , action = StoreConfigFile )
52
57
argparser .add_argument (
53
58
"--env" ,
54
59
default = None ,
55
60
dest = "env_path" ,
56
61
help = "Path to the Unity executable to train" ,
62
+ action = DetectDefault ,
57
63
)
58
64
argparser .add_argument (
59
65
"--lesson" ,
60
66
default = 0 ,
61
67
type = int ,
62
68
help = "The lesson to start with when performing curriculum training" ,
69
+ action = DetectDefault ,
63
70
)
64
71
argparser .add_argument (
65
72
"--keep-checkpoints" ,
@@ -68,19 +75,20 @@ def _create_parser():
68
75
help = "The maximum number of model checkpoints to keep. Checkpoints are saved after the"
69
76
"number of steps specified by the save-freq option. Once the maximum number of checkpoints"
70
77
"has been reached, the oldest checkpoint is deleted when saving a new checkpoint." ,
78
+ action = DetectDefault ,
71
79
)
72
80
argparser .add_argument (
73
81
"--load" ,
74
82
default = False ,
75
83
dest = "load_model" ,
76
- action = "store_true" ,
84
+ action = DetectDefaultStoreTrue ,
77
85
help = argparse .SUPPRESS , # Deprecated but still usable for now.
78
86
)
79
87
argparser .add_argument (
80
88
"--resume" ,
81
89
default = False ,
82
90
dest = "resume" ,
83
- action = "store_true" ,
91
+ action = DetectDefaultStoreTrue ,
84
92
help = "Whether to resume training from a checkpoint. Specify a --run-id to use this option. "
85
93
"If set, the training code loads an already trained model to initialize the neural network "
86
94
"before resuming training. This option is only valid when the models exist, and have the same "
@@ -90,7 +98,7 @@ def _create_parser():
90
98
"--force" ,
91
99
default = False ,
92
100
dest = "force" ,
93
- action = "store_true" ,
101
+ action = DetectDefaultStoreTrue ,
94
102
help = "Whether to force-overwrite this run-id's existing summary and model data. (Without "
95
103
"this flag, attempting to train a model with a run-id that has been used before will throw "
96
104
"an error." ,
@@ -103,6 +111,7 @@ def _create_parser():
103
111
"as the saved model itself. If you use TensorBoard to view the training statistics, "
104
112
"always set a unique run-id for each training run. (The statistics for all runs with the "
105
113
"same id are combined as if they were produced by a the same session.)" ,
114
+ action = DetectDefault ,
106
115
)
107
116
argparser .add_argument (
108
117
"--initialize-from" ,
@@ -112,31 +121,34 @@ def _create_parser():
112
121
"This can be used, for instance, to fine-tune an existing model on a new environment. "
113
122
"Note that the previously saved models must have the same behavior parameters as your "
114
123
"current environment." ,
124
+ action = DetectDefault ,
115
125
)
116
126
argparser .add_argument (
117
127
"--save-freq" ,
118
128
default = 50000 ,
119
129
type = int ,
120
130
help = "How often (in steps) to save the model during training" ,
131
+ action = DetectDefault ,
121
132
)
122
133
argparser .add_argument (
123
134
"--seed" ,
124
135
default = - 1 ,
125
136
type = int ,
126
137
help = "A number to use as a seed for the random number generator used by the training code" ,
138
+ action = DetectDefault ,
127
139
)
128
140
argparser .add_argument (
129
141
"--train" ,
130
142
default = False ,
131
143
dest = "train_model" ,
132
- action = "store_true" ,
144
+ action = DetectDefaultStoreTrue ,
133
145
help = argparse .SUPPRESS ,
134
146
)
135
147
argparser .add_argument (
136
148
"--inference" ,
137
149
default = False ,
138
150
dest = "inference" ,
139
- action = "store_true" ,
151
+ action = DetectDefaultStoreTrue ,
140
152
help = "Whether to run in Python inference mode (i.e. no training). Use with --resume to load "
141
153
"a model trained with an existing run ID." ,
142
154
)
@@ -149,25 +161,27 @@ def _create_parser():
149
161
"will use the port (base_port + worker_id), where the worker_id is sequential IDs given to "
150
162
"each instance from 0 to (num_envs - 1). Note that when training using the Editor rather "
151
163
"than an executable, the base port will be ignored." ,
164
+ action = DetectDefault ,
152
165
)
153
166
argparser .add_argument (
154
167
"--num-envs" ,
155
168
default = 1 ,
156
169
type = int ,
157
170
help = "The number of concurrent Unity environment instances to collect experiences "
158
171
"from when training" ,
172
+ action = DetectDefault ,
159
173
)
160
174
argparser .add_argument (
161
175
"--no-graphics" ,
162
176
default = False ,
163
- action = "store_true" ,
177
+ action = DetectDefaultStoreTrue ,
164
178
help = "Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
165
179
"the graphics driver. Use this only if your agents don't use visual observations." ,
166
180
)
167
181
argparser .add_argument (
168
182
"--debug" ,
169
183
default = False ,
170
- action = "store_true" ,
184
+ action = DetectDefaultStoreTrue ,
171
185
help = "Whether to enable debug-level logging for some parts of the code" ,
172
186
)
173
187
argparser .add_argument (
@@ -178,11 +192,12 @@ def _create_parser():
178
192
"process these as Unity Command Line Arguments. You should choose different argument names if "
179
193
"you want to create environment-specific arguments. All arguments after this flag will be "
180
194
"passed to the executable." ,
195
+ action = DetectDefault ,
181
196
)
182
197
argparser .add_argument (
183
198
"--cpu" ,
184
199
default = False ,
185
- action = "store_true" ,
200
+ action = DetectDefaultStoreTrue ,
186
201
help = "Forces training using CPU only" ,
187
202
)
188
203
@@ -195,41 +210,47 @@ def _create_parser():
195
210
type = int ,
196
211
help = "The width of the executable window of the environment(s) in pixels "
197
212
"(ignored for editor training)." ,
213
+ action = DetectDefault ,
198
214
)
199
215
eng_conf .add_argument (
200
216
"--height" ,
201
217
default = 84 ,
202
218
type = int ,
203
219
help = "The height of the executable window of the environment(s) in pixels "
204
220
"(ignored for editor training)" ,
221
+ action = DetectDefault ,
205
222
)
206
223
eng_conf .add_argument (
207
224
"--quality-level" ,
208
225
default = 5 ,
209
226
type = int ,
210
227
help = "The quality level of the environment(s). Equivalent to calling "
211
228
"QualitySettings.SetQualityLevel in Unity." ,
229
+ action = DetectDefault ,
212
230
)
213
231
eng_conf .add_argument (
214
232
"--time-scale" ,
215
233
default = 20 ,
216
234
type = float ,
217
235
help = "The time scale of the Unity environment(s). Equivalent to setting "
218
236
"Time.timeScale in Unity." ,
237
+ action = DetectDefault ,
219
238
)
220
239
eng_conf .add_argument (
221
240
"--target-frame-rate" ,
222
241
default = - 1 ,
223
242
type = int ,
224
243
help = "The target frame rate of the Unity environment(s). Equivalent to setting "
225
244
"Application.targetFrameRate in Unity." ,
245
+ action = DetectDefault ,
226
246
)
227
247
eng_conf .add_argument (
228
248
"--capture-frame-rate" ,
229
249
default = 60 ,
230
250
type = int ,
231
251
help = "The capture frame rate of the Unity environment(s). Equivalent to setting "
232
252
"Time.captureFramerate in Unity." ,
253
+ action = DetectDefault ,
233
254
)
234
255
return argparser
235
256
@@ -277,26 +298,35 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
277
298
configs loaded from files.
278
299
"""
279
300
argparse_args = vars (args )
280
- config_path = argparse_args ["trainer_config_path" ]
281
- # Load YAML and apply overrides as needed
301
+ run_options_dict = {}
302
+ run_options_dict .update (argparse_args )
303
+ config_path = StoreConfigFile .trainer_config_path
304
+
305
+ # Load YAML
282
306
yaml_config = load_config (config_path )
283
- try :
284
- argparse_args ["behaviors" ] = yaml_config ["behaviors" ]
285
- except KeyError :
307
+ # This is the only option that is not optional and has no defaults.
308
+ if "behaviors" not in yaml_config :
286
309
raise TrainerConfigError (
287
310
"Trainer configurations not found. Make sure your YAML file has a section for behaviors."
288
311
)
312
+ # Use the YAML file values for all values not specified in the CLI.
313
+ for key , val in yaml_config .items ():
314
+ # Detect bad config options
315
+ if not hasattr (RunOptions , key ):
316
+ raise TrainerConfigError (
317
+ "The option {} was specified in your YAML file, but is invalid." .format (
318
+ key
319
+ )
320
+ )
321
+ if key not in DetectDefault .non_default_args :
322
+ run_options_dict [key ] = val
289
323
290
- argparse_args ["parameter_randomization" ] = yaml_config .get (
291
- "parameter_randomization" , None
292
- )
293
324
# Keep deprecated --load working, TODO: remove
294
- argparse_args ["resume" ] = argparse_args ["resume" ] or argparse_args ["load_model" ]
295
- # Since argparse accepts file paths in the config options which don't exist in CommandLineOptions,
296
- # these keys will need to be deleted to use the **/splat operator below.
297
- argparse_args .pop ("trainer_config_path" )
325
+ run_options_dict ["resume" ] = (
326
+ run_options_dict ["resume" ] or run_options_dict ["load_model" ]
327
+ )
298
328
299
- return RunOptions (** vars ( args ) )
329
+ return RunOptions (** run_options_dict )
300
330
301
331
302
332
def get_version_string () -> str :
0 commit comments