25
25
from mlagents .trainers .sampler_class import SamplerManager
26
26
from mlagents .trainers .exception import SamplerException
27
27
from mlagents .trainers .settings import RunOptions
28
+ from mlagents .trainers .training_status import GlobalTrainingStatus
28
29
from mlagents_envs .base_env import BaseEnv
29
30
from mlagents .trainers .subprocess_env_manager import SubprocessEnvManager
30
31
from mlagents_envs .side_channel .side_channel import SideChannel
38
39
39
40
# Module-level logger named after this module (mlagents.trainers convention).
logger = logging_util.get_logger(__name__)
40
41
42
# Name of the JSON file (inside the run logs directory) that
# GlobalTrainingStatus state is saved to and restored from on --resume.
TRAINING_STATUS_FILE_NAME = "training_status.json"
41
44
42
45
def get_version_string () -> str :
43
46
# pylint: disable=no-member
@@ -82,6 +85,11 @@ def run_training(run_seed: int, options: RunOptions) -> None:
82
85
)
83
86
# Make run logs directory
84
87
os .makedirs (run_logs_dir , exist_ok = True )
88
+ # Load any needed states
89
+ if checkpoint_settings .resume :
90
+ GlobalTrainingStatus .load_state (
91
+ os .path .join (run_logs_dir , "training_status.json" )
92
+ )
85
93
# Configure CSV, Tensorboard Writers and StatsReporter
86
94
# We assume reward and episode length are needed in the CSV.
87
95
csv_writer = CSVWriter (
@@ -123,7 +131,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
123
131
env_factory , engine_config , env_settings .num_envs
124
132
)
125
133
maybe_meta_curriculum = try_create_meta_curriculum (
126
- options .curriculum , env_manager , checkpoint_settings .lesson
134
+ options .curriculum , env_manager , restore = checkpoint_settings .resume
127
135
)
128
136
sampler_manager , resampling_interval = create_sampler_manager (
129
137
options .parameter_randomization , run_seed
@@ -159,6 +167,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
159
167
env_manager .close ()
160
168
write_run_options (write_path , options )
161
169
write_timing_tree (run_logs_dir )
170
+ write_training_status (run_logs_dir )
162
171
163
172
164
173
def write_run_options (output_dir : str , run_options : RunOptions ) -> None :
@@ -175,6 +184,10 @@ def write_run_options(output_dir: str, run_options: RunOptions) -> None:
175
184
)
176
185
177
186
187
def write_training_status(output_dir: str) -> None:
    """Persist the global training status under *output_dir*.

    Writes GlobalTrainingStatus state to
    ``<output_dir>/training_status.json`` so a later run can resume it.
    """
    status_path = os.path.join(output_dir, TRAINING_STATUS_FILE_NAME)
    GlobalTrainingStatus.save_state(status_path)
189
+
190
+
178
191
def write_timing_tree (output_dir : str ) -> None :
179
192
timing_path = os .path .join (output_dir , "timers.json" )
180
193
try :
@@ -209,15 +222,14 @@ def create_sampler_manager(sampler_config, run_seed=None):
209
222
210
223
211
224
def try_create_meta_curriculum(
    curriculum_config: Optional[Dict], env: SubprocessEnvManager, restore: bool = False
) -> Optional[MetaCurriculum]:
    """Build a MetaCurriculum from *curriculum_config*, or return None.

    :param curriculum_config: Curriculum configuration dict; a missing or
        empty config yields None (no curriculum learning).
    :param env: The environment manager (currently unused here; kept for
        the caller's signature).
    :param restore: When True, restore previously saved lesson numbers
        via ``try_restore_all_curriculum``.
    :return: The constructed MetaCurriculum, or None.
    """
    # Guard clause: no config (or an empty one) means no meta-curriculum.
    if curriculum_config is None or len(curriculum_config) == 0:
        return None
    curriculum = MetaCurriculum(curriculum_config)
    if restore:
        curriculum.try_restore_all_curriculum()
    return curriculum
222
234
223
235
0 commit comments