Skip to content

Commit 549bea2

Browse files
committed
Fix lesson incrementing (#4279)
1 parent 2f72c2f commit 549bea2

File tree

4 files changed

+93
-5
lines changed

4 files changed

+93
-5
lines changed

ml-agents/mlagents/trainers/environment_parameter_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def update_lessons(
131131
lesson = settings.curriculum[lesson_num]
132132
if (
133133
lesson.completion_criteria is not None
134-
and len(settings.curriculum) > lesson_num
134+
and len(settings.curriculum) > lesson_num + 1
135135
):
136136
behavior_to_consider = lesson.completion_criteria.behavior
137137
if behavior_to_consider in trainer_steps:

ml-agents/mlagents/trainers/exception.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ class TrainerConfigError(Exception):
1919
pass
2020

2121

22+
class TrainerConfigWarning(Warning):
    """
    Any warning related to the configuration of trainers in the ML-Agents Toolkit.

    Derives from the built-in ``Warning`` so it can be used as the ``category``
    argument of ``warnings.warn``.
    """
28+
29+
2230
class CurriculumError(TrainerError):
2331
"""
2432
Any error related to training with a curriculum.

ml-agents/mlagents/trainers/settings.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import attr
24
import cattr
35
from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple, Union
@@ -10,8 +12,7 @@
1012

1113
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
1214
from mlagents.trainers.cli_utils import load_config
13-
from mlagents.trainers.exception import TrainerConfigError
14-
from mlagents.trainers.models import ScheduleType, EncoderType
15+
from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning
1516

1617
from mlagents_envs import logging_util
1718
from mlagents_envs.side_channel.environment_parameters_channel import (
@@ -51,6 +52,17 @@ def as_dict(self):
5152
return cattr.unstructure(self)
5253

5354

55+
class EncoderType(Enum):
    """
    Closed set of encoder type names.

    Each member's value is the lowercase string spelling of that member, so a
    member can be looked up from its string via ``EncoderType("simple")``.
    """

    SIMPLE = "simple"
    NATURE_CNN = "nature_cnn"
    RESNET = "resnet"
59+
60+
61+
class ScheduleType(Enum):
    """
    Closed set of schedule type names.

    Each member's value is the lowercase string spelling of that member, so a
    member can be looked up from its string via ``ScheduleType("linear")``.
    """

    CONSTANT = "constant"
    LINEAR = "linear"
64+
65+
5466
@attr.s(auto_attribs=True)
5567
class NetworkSettings:
5668
@attr.s
@@ -433,14 +445,20 @@ class EnvironmentParameterSettings:
433445
def _check_lesson_chain(lessons, parameter_name):
    """
    Ensures that when using curriculum, all non-terminal lessons have a valid
    CompletionCriteria, and that the terminal lesson does not contain a CompletionCriteria.

    :param lessons: Ordered sequence of lessons. Each element must expose a
        ``completion_criteria`` attribute, which is ``None`` when unset.
    :param parameter_name: Name of the environment parameter the lessons belong
        to; only used to build the error / warning messages.
    :raises TrainerConfigError: If any lesson other than the last one has no
        completion_criteria.

    A TrainerConfigWarning is issued (not raised) when the last lesson defines a
    completion_criteria.
    """
    num_lessons = len(lessons)
    for index, lesson in enumerate(lessons):
        is_terminal = index == num_lessons - 1
        has_criteria = lesson.completion_criteria is not None
        if not is_terminal and not has_criteria:
            raise TrainerConfigError(
                f"A non-terminal lesson does not have a completion_criteria for {parameter_name}."
            )
        if is_terminal and has_criteria:
            # The two message fragments are concatenated by the parser, so the
            # separating space after the first sentence has to be explicit.
            warnings.warn(
                f"Your final lesson definition contains completion_criteria for {parameter_name}. "
                "It will be ignored.",
                TrainerConfigWarning,
            )
444462

445463
@staticmethod
446464
def structure(d: Mapping, t: type) -> Dict[str, "EnvironmentParameterSettings"]:

ml-agents/mlagents/trainers/tests/test_env_param_manager.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import yaml
33

44

5-
from mlagents.trainers.exception import TrainerConfigError
5+
from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning
66
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
77
from mlagents.trainers.settings import (
88
RunOptions,
@@ -154,13 +154,75 @@ def test_curriculum_conversion():
154154
"""
155155

156156

157+
test_bad_curriculum_all_competion_criteria_config_yaml = """
158+
environment_parameters:
159+
param_1:
160+
curriculum:
161+
- name: Lesson1
162+
completion_criteria:
163+
measure: reward
164+
behavior: fake_behavior
165+
threshold: 30
166+
min_lesson_length: 100
167+
require_reset: true
168+
value: 1
169+
- name: Lesson2
170+
completion_criteria:
171+
measure: reward
172+
behavior: fake_behavior
173+
threshold: 30
174+
min_lesson_length: 100
175+
require_reset: true
176+
value: 2
177+
- name: Lesson3
178+
completion_criteria:
179+
measure: reward
180+
behavior: fake_behavior
181+
threshold: 30
182+
min_lesson_length: 100
183+
require_reset: true
184+
value:
185+
sampler_type: uniform
186+
sampler_parameters:
187+
min_value: 1
188+
max_value: 3
189+
"""
190+
191+
157192
def test_curriculum_raises_no_completion_criteria_conversion():
    """
    Building RunOptions from the "no completion criteria" fixture must raise
    TrainerConfigError.
    """
    bad_config = yaml.safe_load(test_bad_curriculum_no_competion_criteria_config_yaml)
    with pytest.raises(TrainerConfigError):
        RunOptions.from_dict(bad_config)
162197

163198

199+
def test_curriculum_raises_all_completion_criteria_conversion():
    """
    A curriculum in which every lesson (including the last) has
    completion_criteria must emit a TrainerConfigWarning, and repeated
    ``update_lessons`` calls must stop advancing at the last lesson (index 2).
    """
    # Outcome expected from each successive update_lessons call: two lesson
    # changes, then nothing once the final lesson has been reached.
    expected_outcomes = ((True, True), (True, True), (False, False))

    with pytest.warns(TrainerConfigWarning):
        config = yaml.safe_load(test_bad_curriculum_all_competion_criteria_config_yaml)
        run_options = RunOptions.from_dict(config)

        param_manager = EnvironmentParameterManager(
            run_options.environment_parameters, 1337, False
        )
        for expected in expected_outcomes:
            # Fresh argument objects are built for every call, as in three
            # separately written calls.
            outcome = param_manager.update_lessons(
                trainer_steps={"fake_behavior": 500},
                trainer_max_steps={"fake_behavior": 1000},
                trainer_reward_buffer={"fake_behavior": [1000] * 101},
            )
            assert outcome == expected
        assert param_manager.get_current_lesson_number() == {"param_1": 2}
224+
225+
164226
test_everything_config_yaml = """
165227
environment_parameters:
166228
param_1:

0 commit comments

Comments
 (0)