diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index 63af614370..c308881a2f 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -70,7 +70,7 @@ def run_training(run_seed: int, options: RunOptions) -> None: base_path = "results" write_path = os.path.join(base_path, checkpoint_settings.run_id) maybe_init_path = ( - os.path.join(base_path, checkpoint_settings.run_id) + os.path.join(base_path, checkpoint_settings.initialize_from) if checkpoint_settings.initialize_from else None ) diff --git a/ml-agents/mlagents/trainers/tests/test_learn.py b/ml-agents/mlagents/trainers/tests/test_learn.py index daacc1c4d3..5a67036ade 100644 --- a/ml-agents/mlagents/trainers/tests/test_learn.py +++ b/ml-agents/mlagents/trainers/tests/test_learn.py @@ -22,6 +22,13 @@ def basic_options(extra_args=None): {} """ +MOCK_INITIALIZE_YAML = """ + behaviors: + {} + checkpoint_settings: + initialize_from: notuselessrun + """ + MOCK_PARAMETER_YAML = """ behaviors: {} @@ -32,6 +39,7 @@ def basic_options(extra_args=None): seed: 9870 checkpoint_settings: run_id: uselessrun + initialize_from: notuselessrun debug: false """ @@ -71,7 +79,7 @@ def test_run_training( mock_env.external_brain_names = [] mock_env.academy_name = "TestAcademyName" create_environment_factory.return_value = mock_env - load_config.return_value = yaml.safe_load(MOCK_YAML) + load_config.return_value = yaml.safe_load(MOCK_INITIALIZE_YAML) mock_init = MagicMock(return_value=None) with patch.object(TrainerController, "__init__", mock_init): @@ -88,7 +96,9 @@ def test_run_training( sampler_manager_mock.return_value, None, ) - handle_dir_mock.assert_called_once_with("results/ppo", False, False, None) + handle_dir_mock.assert_called_once_with( + "results/ppo", False, False, "results/notuselessrun" + ) write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs") write_run_options_mock.assert_called_once_with("results/ppo", options) StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py @@ -120,6 +130,7 @@ def test_commandline_args(mock_file): assert opt.checkpoint_settings.resume is False assert opt.checkpoint_settings.inference is False assert opt.checkpoint_settings.run_id == "ppo" + assert opt.checkpoint_settings.initialize_from is None assert opt.env_settings.seed == -1 assert opt.env_settings.base_port == 5005 assert opt.env_settings.num_envs == 1 @@ -136,6 +147,7 @@ def test_commandline_args(mock_file): "--seed=7890", "--train", "--base-port=4004", + "--initialize-from=testdir", "--num-envs=2", "--no-graphics", "--debug", @@ -146,6 +158,7 @@ def test_commandline_args(mock_file): assert opt.env_settings.env_path == "./myenvfile" assert opt.parameter_randomization is None assert opt.checkpoint_settings.run_id == "myawesomerun" + assert opt.checkpoint_settings.initialize_from == "testdir" assert opt.env_settings.seed == 7890 assert opt.env_settings.base_port == 4004 assert opt.env_settings.num_envs == 2 @@ -164,6 +177,7 @@ def test_yaml_args(mock_file): assert opt.env_settings.env_path == "./oldenvfile" assert opt.parameter_randomization is None assert opt.checkpoint_settings.run_id == "uselessrun" + assert opt.checkpoint_settings.initialize_from == "notuselessrun" assert opt.env_settings.seed == 9870 assert opt.env_settings.base_port == 4001 assert opt.env_settings.num_envs == 4