Skip to content

Commit 5d02292

Browse files
author
Ervin T
authored
[refactor] Store and restore state along with checkpoints (#4025)
1 parent f4d4848 commit 5d02292

File tree

13 files changed

+242
-46
lines changed

13 files changed

+242
-46
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@ and this project adheres to
1717
- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
1818
were replaced by `allow_multiple_obs` which allows one or more visual observations and
1919
vector observations to be used simultaneously. (#3981) Thank you @shakenes !
20-
### Minor Changes
21-
#### com.unity.ml-agents (C#)
22-
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
23-
observations via reflection. (#3925, #4006)
24-
#### ml-agents / ml-agents-envs / gym-unity (Python)
2520
- Curriculum and Parameter Randomization configurations have been merged
2621
into the main training configuration file. Note that this means training
2722
configuration files are now environment-specific. (#3791)
2823
- The format for trainer configuration has changed, and the "default" behavior has been deprecated.
2924
See the [Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Migrating.md) for more details. (#3936)
3025
- Training artifacts (trained models, summaries) are now found in the `results/`
3126
directory. (#3829)
27+
- When using Curriculum, the current lesson will resume if training is quit and resumed. As such,
28+
the `--lesson` CLI option has been removed. (#4025)
29+
### Minor Changes
30+
#### com.unity.ml-agents (C#)
31+
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
32+
observations via reflection. (#3925, #4006)
33+
#### ml-agents / ml-agents-envs / gym-unity (Python)
3234
- Unity Player logs are now written out to the results directory. (#3877)
3335
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
3436
- When trying to load/resume from a checkpoint created with an earlier verison of ML-Agents,

docs/Migrating.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ double-check that the versions are in the same. The versions can be found in
2828
- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
2929
were replaced by `allow_multiple_obs` which allows one or more visual observations and
3030
vector observations to be used simultaneously.
31+
- `--lesson` has been removed from the CLI. Lessons will resume when using `--resume`.
32+
To start at a different lesson, modify your Curriculum configuration.
3133

3234
### Steps to Migrate
3335
- To upgrade your configuration files, an upgrade script has been provided. Run `python config/update_config.py

docs/Training-ML-Agents.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,9 @@ train agents in the Wall Jump environment with curriculum learning, we can run:
420420
mlagents-learn config/ppo/WallJump_curriculum.yaml --run-id=wall-jump-curriculum
421421
```
422422

423-
We can then keep track of the current lessons and progresses via TensorBoard.
424-
425-
**Note**: If you are resuming a training session that uses curriculum, please
426-
pass the number of the last-reached lesson using the `--lesson` flag when
427-
running `mlagents-learn`.
423+
We can then keep track of the current lessons and progresses via TensorBoard. If you've terminated
424+
the run, you can resume it using `--resume` and lesson progress will start off where it
425+
ended.
428426

429427
### Environment Parameter Randomization
430428

ml-agents/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,3 @@ scene with the ML-Agents SDK, check out the main
3131
cooperative behavior among different agents is not stable.
3232
- Resuming self-play from a checkpoint resets the reported ELO to the default
3333
value.
34-
- Resuming curriculum learning from a checkpoint requires the last lesson be
35-
specified using the `--lesson` CLI option

ml-agents/mlagents/trainers/cli_utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,6 @@ def _create_parser() -> argparse.ArgumentParser:
5959
help="Path to the Unity executable to train",
6060
action=DetectDefault,
6161
)
62-
argparser.add_argument(
63-
"--lesson",
64-
default=0,
65-
type=int,
66-
help="The lesson to start with when performing curriculum training",
67-
action=DetectDefault,
68-
)
6962
argparser.add_argument(
7063
"--load",
7164
default=False,

ml-agents/mlagents/trainers/learn.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mlagents.trainers.sampler_class import SamplerManager
2626
from mlagents.trainers.exception import SamplerException
2727
from mlagents.trainers.settings import RunOptions
28+
from mlagents.trainers.training_status import GlobalTrainingStatus
2829
from mlagents_envs.base_env import BaseEnv
2930
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
3031
from mlagents_envs.side_channel.side_channel import SideChannel
@@ -38,6 +39,8 @@
3839

3940
logger = logging_util.get_logger(__name__)
4041

42+
TRAINING_STATUS_FILE_NAME = "training_status.json"
43+
4144

4245
def get_version_string() -> str:
4346
# pylint: disable=no-member
@@ -82,6 +85,11 @@ def run_training(run_seed: int, options: RunOptions) -> None:
8285
)
8386
# Make run logs directory
8487
os.makedirs(run_logs_dir, exist_ok=True)
88+
# Load any needed states
89+
if checkpoint_settings.resume:
90+
GlobalTrainingStatus.load_state(
91+
os.path.join(run_logs_dir, "training_status.json")
92+
)
8593
# Configure CSV, Tensorboard Writers and StatsReporter
8694
# We assume reward and episode length are needed in the CSV.
8795
csv_writer = CSVWriter(
@@ -123,7 +131,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
123131
env_factory, engine_config, env_settings.num_envs
124132
)
125133
maybe_meta_curriculum = try_create_meta_curriculum(
126-
options.curriculum, env_manager, checkpoint_settings.lesson
134+
options.curriculum, env_manager, restore=checkpoint_settings.resume
127135
)
128136
sampler_manager, resampling_interval = create_sampler_manager(
129137
options.parameter_randomization, run_seed
@@ -159,6 +167,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
159167
env_manager.close()
160168
write_run_options(write_path, options)
161169
write_timing_tree(run_logs_dir)
170+
write_training_status(run_logs_dir)
162171

163172

164173
def write_run_options(output_dir: str, run_options: RunOptions) -> None:
@@ -175,6 +184,10 @@ def write_run_options(output_dir: str, run_options: RunOptions) -> None:
175184
)
176185

177186

187+
def write_training_status(output_dir: str) -> None:
188+
GlobalTrainingStatus.save_state(os.path.join(output_dir, TRAINING_STATUS_FILE_NAME))
189+
190+
178191
def write_timing_tree(output_dir: str) -> None:
179192
timing_path = os.path.join(output_dir, "timers.json")
180193
try:
@@ -209,15 +222,14 @@ def create_sampler_manager(sampler_config, run_seed=None):
209222

210223

211224
def try_create_meta_curriculum(
212-
curriculum_config: Optional[Dict], env: SubprocessEnvManager, lesson: int
225+
curriculum_config: Optional[Dict], env: SubprocessEnvManager, restore: bool = False
213226
) -> Optional[MetaCurriculum]:
214227
if curriculum_config is None or len(curriculum_config) <= 0:
215228
return None
216229
else:
217230
meta_curriculum = MetaCurriculum(curriculum_config)
218-
# TODO: Should be able to start learning at different lesson numbers
219-
# for each curriculum.
220-
meta_curriculum.set_all_curricula_to_lesson_num(lesson)
231+
if restore:
232+
meta_curriculum.try_restore_all_curriculum()
221233
return meta_curriculum
222234

223235

ml-agents/mlagents/trainers/meta_curriculum.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict, Set
44
from mlagents.trainers.curriculum import Curriculum
55
from mlagents.trainers.settings import CurriculumSettings
6+
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
67

78
from mlagents_envs.logging_util import get_logger
89

@@ -115,16 +116,22 @@ def increment_lessons(self, measure_vals, reward_buff_sizes=None):
115116
)
116117
return ret
117118

118-
def set_all_curricula_to_lesson_num(self, lesson_num):
119-
"""Sets all the curricula in this meta curriculum to a specified
120-
lesson number.
121-
122-
Args:
123-
lesson_num (int): The lesson number which all the curricula will
124-
be set to.
119+
def try_restore_all_curriculum(self):
125120
"""
126-
for _, curriculum in self.brains_to_curricula.items():
127-
curriculum.lesson_num = lesson_num
121+
Tries to restore all the curriculums to what is saved in training_status.json
122+
"""
123+
124+
for brain_name, curriculum in self.brains_to_curricula.items():
125+
lesson_num = GlobalTrainingStatus.get_parameter_state(
126+
brain_name, StatusType.LESSON_NUM
127+
)
128+
if lesson_num is not None:
129+
logger.info(
130+
f"Resuming curriculum for {brain_name} at lesson {lesson_num}."
131+
)
132+
curriculum.lesson_num = lesson_num
133+
else:
134+
curriculum.lesson_num = 0
128135

129136
def get_config(self):
130137
"""Get the combined configuration of all curricula in this

ml-agents/mlagents/trainers/settings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ class CheckpointSettings:
275275
force: bool = parser.get_default("force")
276276
train_model: bool = parser.get_default("train_model")
277277
inference: bool = parser.get_default("inference")
278-
lesson: int = parser.get_default("lesson")
279278

280279

281280
@attr.s(auto_attribs=True)

ml-agents/mlagents/trainers/tests/test_learn.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def basic_options(extra_args=None):
3131
base_port: 4001
3232
seed: 9870
3333
checkpoint_settings:
34-
lesson: 2
3534
run_id: uselessrun
3635
save_freq: 654321
3736
debug: false
@@ -120,7 +119,6 @@ def test_commandline_args(mock_file):
120119
assert opt.behaviors == {}
121120
assert opt.env_settings.env_path is None
122121
assert opt.parameter_randomization is None
123-
assert opt.checkpoint_settings.lesson == 0
124122
assert opt.checkpoint_settings.resume is False
125123
assert opt.checkpoint_settings.inference is False
126124
assert opt.checkpoint_settings.run_id == "ppo"
@@ -135,7 +133,6 @@ def test_commandline_args(mock_file):
135133
full_args = [
136134
"mytrainerpath",
137135
"--env=./myenvfile",
138-
"--lesson=3",
139136
"--resume",
140137
"--inference",
141138
"--run-id=myawesomerun",
@@ -152,7 +149,6 @@ def test_commandline_args(mock_file):
152149
assert opt.behaviors == {}
153150
assert opt.env_settings.env_path == "./myenvfile"
154151
assert opt.parameter_randomization is None
155-
assert opt.checkpoint_settings.lesson == 3
156152
assert opt.checkpoint_settings.run_id == "myawesomerun"
157153
assert opt.checkpoint_settings.save_freq == 123456
158154
assert opt.env_settings.seed == 7890
@@ -172,7 +168,6 @@ def test_yaml_args(mock_file):
172168
assert opt.behaviors == {}
173169
assert opt.env_settings.env_path == "./oldenvfile"
174170
assert opt.parameter_randomization is None
175-
assert opt.checkpoint_settings.lesson == 2
176171
assert opt.checkpoint_settings.run_id == "uselessrun"
177172
assert opt.checkpoint_settings.save_freq == 654321
178173
assert opt.env_settings.seed == 9870
@@ -185,7 +180,6 @@ def test_yaml_args(mock_file):
185180
full_args = [
186181
"mytrainerpath",
187182
"--env=./myenvfile",
188-
"--lesson=3",
189183
"--resume",
190184
"--inference",
191185
"--run-id=myawesomerun",
@@ -202,7 +196,6 @@ def test_yaml_args(mock_file):
202196
assert opt.behaviors == {}
203197
assert opt.env_settings.env_path == "./myenvfile"
204198
assert opt.parameter_randomization is None
205-
assert opt.checkpoint_settings.lesson == 3
206199
assert opt.checkpoint_settings.run_id == "myawesomerun"
207200
assert opt.checkpoint_settings.save_freq == 123456
208201
assert opt.env_settings.seed == 7890

ml-agents/mlagents/trainers/tests/test_meta_curriculum.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from unittest.mock import patch, Mock
2+
from unittest.mock import patch, Mock, call
33

44
from mlagents.trainers.meta_curriculum import MetaCurriculum
55

@@ -11,6 +11,7 @@
1111
)
1212
from mlagents.trainers.tests.test_curriculum import dummy_curriculum_config
1313
from mlagents.trainers.settings import CurriculumSettings
14+
from mlagents.trainers.training_status import StatusType
1415

1516

1617
@pytest.fixture
@@ -77,14 +78,26 @@ def test_increment_lessons_with_reward_buff_sizes(
7778
curriculum_b.increment_lesson.assert_not_called()
7879

7980

80-
def test_set_all_curriculums_to_lesson_num():
81+
@patch("mlagents.trainers.meta_curriculum.GlobalTrainingStatus")
82+
def test_restore_curriculums(mock_trainingstatus):
8183
meta_curriculum = MetaCurriculum(test_meta_curriculum_config)
82-
83-
meta_curriculum.set_all_curricula_to_lesson_num(2)
84-
84+
# Test restore to value
85+
mock_trainingstatus.get_parameter_state.return_value = 2
86+
meta_curriculum.try_restore_all_curriculum()
87+
mock_trainingstatus.get_parameter_state.assert_has_calls(
88+
[call("Brain1", StatusType.LESSON_NUM), call("Brain2", StatusType.LESSON_NUM)],
89+
any_order=True,
90+
)
8591
assert meta_curriculum.brains_to_curricula["Brain1"].lesson_num == 2
8692
assert meta_curriculum.brains_to_curricula["Brain2"].lesson_num == 2
8793

94+
# Test restore to None
95+
mock_trainingstatus.get_parameter_state.return_value = None
96+
meta_curriculum.try_restore_all_curriculum()
97+
98+
assert meta_curriculum.brains_to_curricula["Brain1"].lesson_num == 0
99+
assert meta_curriculum.brains_to_curricula["Brain2"].lesson_num == 0
100+
88101

89102
def test_get_config():
90103
meta_curriculum = MetaCurriculum(test_meta_curriculum_config)

0 commit comments

Comments
 (0)