Skip to content

[bug-fix] Fix regression in --initialize-from feature #4086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
base_path = "results"
write_path = os.path.join(base_path, checkpoint_settings.run_id)
maybe_init_path = (
os.path.join(base_path, checkpoint_settings.run_id)
os.path.join(base_path, checkpoint_settings.initialize_from)
if checkpoint_settings.initialize_from
else None
)
Expand Down
18 changes: 16 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def basic_options(extra_args=None):
{}
"""

MOCK_INITIALIZE_YAML = """
behaviors:
{}
checkpoint_settings:
initialize_from: notuselessrun
"""

MOCK_PARAMETER_YAML = """
behaviors:
{}
Expand All @@ -32,6 +39,7 @@ def basic_options(extra_args=None):
seed: 9870
checkpoint_settings:
run_id: uselessrun
initialize_from: notuselessrun
debug: false
"""

Expand Down Expand Up @@ -71,7 +79,7 @@ def test_run_training(
mock_env.external_brain_names = []
mock_env.academy_name = "TestAcademyName"
create_environment_factory.return_value = mock_env
load_config.return_value = yaml.safe_load(MOCK_YAML)
load_config.return_value = yaml.safe_load(MOCK_INITIALIZE_YAML)

mock_init = MagicMock(return_value=None)
with patch.object(TrainerController, "__init__", mock_init):
Expand All @@ -88,7 +96,9 @@ def test_run_training(
sampler_manager_mock.return_value,
None,
)
handle_dir_mock.assert_called_once_with("results/ppo", False, False, None)
handle_dir_mock.assert_called_once_with(
"results/ppo", False, False, "results/notuselessrun"
)
write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs")
write_run_options_mock.assert_called_once_with("results/ppo", options)
StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py
Expand Down Expand Up @@ -120,6 +130,7 @@ def test_commandline_args(mock_file):
assert opt.checkpoint_settings.resume is False
assert opt.checkpoint_settings.inference is False
assert opt.checkpoint_settings.run_id == "ppo"
assert opt.checkpoint_settings.initialize_from is None
assert opt.env_settings.seed == -1
assert opt.env_settings.base_port == 5005
assert opt.env_settings.num_envs == 1
Expand All @@ -136,6 +147,7 @@ def test_commandline_args(mock_file):
"--seed=7890",
"--train",
"--base-port=4004",
"--initialize-from=testdir",
"--num-envs=2",
"--no-graphics",
"--debug",
Expand All @@ -146,6 +158,7 @@ def test_commandline_args(mock_file):
assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.initialize_from == "testdir"
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2
Expand All @@ -164,6 +177,7 @@ def test_yaml_args(mock_file):
assert opt.env_settings.env_path == "./oldenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "uselessrun"
assert opt.checkpoint_settings.initialize_from == "notuselessrun"
assert opt.env_settings.seed == 9870
assert opt.env_settings.base_port == 4001
assert opt.env_settings.num_envs == 4
Expand Down