diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 240e5bd7f5..d2b402c8f8 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -33,6 +33,7 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Unity Player logs are now written out to the results directory. (#3877)
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
+- The `--save-freq` CLI option has been removed, and replaced by a `checkpoint_interval` option in the trainer configuration YAML. (#4034)
- When trying to load/resume from a checkpoint created with an earlier version of ML-Agents,
a warning will be thrown. (#4035)
### Bug Fixes
diff --git a/docs/Migrating.md b/docs/Migrating.md
index f037bf0948..a80ea86162 100644
--- a/docs/Migrating.md
+++ b/docs/Migrating.md
@@ -28,6 +28,8 @@ double-check that the versions are in the same. The versions can be found in
- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
were replaced by `allow_multiple_obs` which allows one or more visual observations and
vector observations to be used simultaneously.
+- `--save-freq` has been removed from the CLI. The checkpoint frequency is now set with the
+ `checkpoint_interval` option in the trainer configuration file.
- `--lesson` has been removed from the CLI. Lessons will resume when using `--resume`.
To start at a different lesson, modify your Curriculum configuration.
@@ -49,6 +51,8 @@ vector observations to be used simultaneously.
- If you use the `UnityToGymWrapper`, remove `use_visual` and `allow_multiple_visual_obs`
from the constructor and add `allow_multiple_obs = True` if the environment contains either
both visual and vector observations or multiple visual observations.
+ - If you were setting `--save-freq` in the CLI, add a `checkpoint_interval` value in your
+ trainer configuration, and set it equal to `save-freq * n_agents_in_scene`.
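A rough illustration of that conversion, using purely hypothetical values for the old `--save-freq` setting and the number of agents in the scene:

```python
# Hypothetical example values; substitute the --save-freq value you used previously
# and the number of agents using this behavior in your scene.
old_save_freq = 50000
n_agents_in_scene = 12

# Per the migration note above: checkpoint_interval = save-freq * n_agents_in_scene
checkpoint_interval = old_save_freq * n_agents_in_scene
print(checkpoint_interval)  # 600000 -> set this in the trainer configuration YAML
```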
## Migrating from 0.15 to Release 1
diff --git a/docs/Training-Configuration-File.md b/docs/Training-Configuration-File.md
index 6df42c6ade..16b384d167 100644
--- a/docs/Training-Configuration-File.md
+++ b/docs/Training-Configuration-File.md
@@ -30,7 +30,8 @@ choice of the trainer (which we review on subsequent sections).
| `summary_freq` | (default = `50000`) Number of experiences that need to be collected before generating and displaying training statistics. This determines the granularity of the graphs in Tensorboard. |
| `time_horizon` | (default = `64`) How many steps of experience to collect per-agent before adding it to the experience buffer. When this limit is reached before the end of an episode, a value estimate is used to predict the overall expected reward from the agent's current state. As such, this parameter trades off between a less biased, but higher variance estimate (long time horizon) and more biased, but less varied estimate (short time horizon). In cases where there are frequent rewards within an episode, or episodes are prohibitively large, a smaller number can be more ideal. This number should be large enough to capture all the important behavior within a sequence of an agent's actions. Typical range: `32` - `2048` |
| `max_steps` | (default = `500000`) Total number of steps (i.e., observation collected and action taken) that must be taken in the environment (or across all environments if using multiple in parallel) before ending the training process. If you have multiple agents with the same behavior name within your environment, all steps taken by those agents will contribute to the same `max_steps` count. Typical range: `5e5` - `1e7` |
-| `keep_checkpoints` | (default = `5`) The maximum number of model checkpoints to keep. Checkpoints are saved after the number of steps specified by the save-freq option. Once the maximum number of checkpoints has been reached, the oldest checkpoint is deleted when saving a new checkpoint. |
+| `keep_checkpoints` | (default = `5`) The maximum number of model checkpoints to keep. Checkpoints are saved after the number of steps specified by the `checkpoint_interval` option. Once the maximum number of checkpoints has been reached, the oldest checkpoint is deleted when saving a new checkpoint. |
+| `checkpoint_interval` | (default = `500000`) The number of experiences collected between each checkpoint by the trainer. A maximum of `keep_checkpoints` checkpoints are saved before old ones are deleted. |
| `init_path` | (default = None) Initialize trainer from a previously saved model. Note that the prior run should have used the same trainer configurations as the current run, and have been saved with the same version of ML-Agents. You should provide the full path to the folder where the checkpoints were saved, e.g. `./models/{run-id}/{behavior_name}`. This option is provided in case you want to initialize different behaviors from different runs; in most cases, it is sufficient to use the `--initialize-from` CLI parameter to initialize all models from the same run. |
| `threaded` | (default = `true`) By default, model updates can happen while the environment is being stepped. This violates the [on-policy](https://spinningup.openai.com/en/latest/user/algorithms.html#the-on-policy-algorithms) assumption of PPO slightly in exchange for a training speedup. To maintain the strict on-policyness of PPO, you can disable parallel updates by setting `threaded` to `false`. There is usually no reason to turn `threaded` off for SAC. |
| `hyperparameters -> learning_rate` | (default = `3e-4`) Initial learning rate for gradient descent. Corresponds to the strength of each gradient descent update step. This should typically be decreased if training is unstable, and the reward does not consistently increase. Typical range: `1e-5` - `1e-3` |
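A minimal sketch, not the actual ML-Agents implementation, of how `checkpoint_interval` and `keep_checkpoints` interact as described in the table above: a checkpoint is produced every `checkpoint_interval` experiences, and once `keep_checkpoints` checkpoints exist, the oldest is dropped whenever a new one is saved.

```python
from collections import deque
from typing import List


def surviving_checkpoints(max_steps: int, checkpoint_interval: int, keep_checkpoints: int) -> List[int]:
    """Return the step counts whose checkpoints would still be on disk at the end
    of training, assuming oldest-first deletion (illustrative only)."""
    kept: deque = deque(maxlen=keep_checkpoints)  # oldest entries fall off automatically
    for step in range(checkpoint_interval, max_steps + 1, checkpoint_interval):
        kept.append(step)
    return list(kept)


# With the documented defaults, only the 5 most recent checkpoints remain.
print(surviving_checkpoints(5_000_000, 500_000, 5))
# [3000000, 3500000, 4000000, 4500000, 5000000]
```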
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 22fcd58f53..3d6b87b329 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -231,6 +231,7 @@ behaviors:
time_horizon: 64
summary_freq: 10000
keep_checkpoints: 5
+ checkpoint_interval: 50000
threaded: true
init_path: null
diff --git a/docs/Using-Tensorboard.md b/docs/Using-Tensorboard.md
index f3c1cde367..b2a54a0b2b 100644
--- a/docs/Using-Tensorboard.md
+++ b/docs/Using-Tensorboard.md
@@ -29,9 +29,6 @@ runs you want to display. You can select multiple run-ids to compare statistics.
The TensorBoard window also provides options for how to display and smooth
graphs.
-When you run the training program, `mlagents-learn`, you can use the
-`--save-freq` option to specify how frequently to save the statistics.
-
## The ML-Agents Toolkit training statistics
The ML-Agents training program saves the following statistics:
diff --git a/ml-agents/mlagents/trainers/cli_utils.py b/ml-agents/mlagents/trainers/cli_utils.py
index dfbc18c9a2..44624e6657 100644
--- a/ml-agents/mlagents/trainers/cli_utils.py
+++ b/ml-agents/mlagents/trainers/cli_utils.py
@@ -105,13 +105,6 @@ def _create_parser() -> argparse.ArgumentParser:
"current environment.",
action=DetectDefault,
)
- argparser.add_argument(
- "--save-freq",
- default=50000,
- type=int,
- help="How often (in steps) to save the model during training",
- action=DetectDefault,
- )
argparser.add_argument(
"--seed",
default=-1,
diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py
index 8e81ad5154..48c5a1f76a 100644
--- a/ml-agents/mlagents/trainers/ghost/trainer.py
+++ b/ml-agents/mlagents/trainers/ghost/trainer.py
@@ -240,7 +240,7 @@ def advance(self) -> None:
except AgentManagerQueue.Empty:
pass
- self.next_summary_step = self.trainer.next_summary_step
+ self._next_summary_step = self.trainer._next_summary_step
self.trainer.advance()
if self.get_step - self.last_team_change > self.steps_to_train_team:
self.controller.change_training_team(self.get_step)
diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py
index b6ff71bc69..63af614370 100644
--- a/ml-agents/mlagents/trainers/learn.py
+++ b/ml-agents/mlagents/trainers/learn.py
@@ -152,7 +152,6 @@ def run_training(run_seed: int, options: RunOptions) -> None:
trainer_factory,
write_path,
checkpoint_settings.run_id,
- checkpoint_settings.save_freq,
maybe_meta_curriculum,
not checkpoint_settings.inference,
run_seed,
diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py
index 743091b770..5abb5774af 100644
--- a/ml-agents/mlagents/trainers/ppo/trainer.py
+++ b/ml-agents/mlagents/trainers/ppo/trainer.py
@@ -253,7 +253,6 @@ def add_policy(
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly
self.step = policy.get_current_step()
- self.next_summary_step = self._get_next_summary_step()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""
diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py
index fda8f8ca1a..848362a29c 100644
--- a/ml-agents/mlagents/trainers/sac/trainer.py
+++ b/ml-agents/mlagents/trainers/sac/trainer.py
@@ -333,7 +333,6 @@ def add_policy(
self.reward_signal_update_steps = int(
max(1, self.step / self.reward_signal_steps_per_update)
)
- self.next_summary_step = self._get_next_summary_step()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index cc537e4edd..d3ea4205ca 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -192,6 +192,7 @@ def _set_default_hyperparameters(self):
init_path: Optional[str] = None
output_path: str = "default"
keep_checkpoints: int = 5
+ checkpoint_interval: int = 500000
max_steps: int = 500000
time_horizon: int = 64
summary_freq: int = 50000
@@ -267,7 +268,6 @@ class MeasureType:
@attr.s(auto_attribs=True)
class CheckpointSettings:
- save_freq: int = parser.get_default("save_freq")
run_id: str = parser.get_default("run_id")
initialize_from: str = parser.get_default("initialize_from")
load_model: bool = parser.get_default("load_model")
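For illustration only, a simplified stand-in for the new field (the real `TrainerSettings` in `mlagents.trainers.settings` has many more fields and custom structuring logic); it just shows that `checkpoint_interval` defaults to `500000` and can be overridden from the trainer configuration:

```python
import attr


@attr.s(auto_attribs=True)
class TrainerSettingsSketch:
    # Simplified, hypothetical stand-in for mlagents.trainers.settings.TrainerSettings.
    keep_checkpoints: int = 5
    checkpoint_interval: int = 500000
    summary_freq: int = 50000


# The default applies when the trainer configuration omits checkpoint_interval ...
assert TrainerSettingsSketch().checkpoint_interval == 500000
# ... and a value supplied in the YAML (e.g. 50000, as in docs/Training-ML-Agents.md) overrides it.
assert TrainerSettingsSketch(checkpoint_interval=50000).checkpoint_interval == 50000
```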
diff --git a/ml-agents/mlagents/trainers/tests/test_learn.py b/ml-agents/mlagents/trainers/tests/test_learn.py
index a2b72a66c9..daacc1c4d3 100644
--- a/ml-agents/mlagents/trainers/tests/test_learn.py
+++ b/ml-agents/mlagents/trainers/tests/test_learn.py
@@ -32,7 +32,6 @@ def basic_options(extra_args=None):
seed: 9870
checkpoint_settings:
run_id: uselessrun
- save_freq: 654321
debug: false
"""
@@ -83,7 +82,6 @@ def test_run_training(
trainer_factory_mock.return_value,
"results/ppo",
"ppo",
- 50000,
None,
True,
0,
@@ -122,7 +120,6 @@ def test_commandline_args(mock_file):
assert opt.checkpoint_settings.resume is False
assert opt.checkpoint_settings.inference is False
assert opt.checkpoint_settings.run_id == "ppo"
- assert opt.checkpoint_settings.save_freq == 50000
assert opt.env_settings.seed == -1
assert opt.env_settings.base_port == 5005
assert opt.env_settings.num_envs == 1
@@ -136,7 +133,6 @@ def test_commandline_args(mock_file):
"--resume",
"--inference",
"--run-id=myawesomerun",
- "--save-freq=123456",
"--seed=7890",
"--train",
"--base-port=4004",
@@ -150,7 +146,6 @@ def test_commandline_args(mock_file):
assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
- assert opt.checkpoint_settings.save_freq == 123456
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2
@@ -169,7 +164,6 @@ def test_yaml_args(mock_file):
assert opt.env_settings.env_path == "./oldenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "uselessrun"
- assert opt.checkpoint_settings.save_freq == 654321
assert opt.env_settings.seed == 9870
assert opt.env_settings.base_port == 4001
assert opt.env_settings.num_envs == 4
@@ -183,7 +177,6 @@ def test_yaml_args(mock_file):
"--resume",
"--inference",
"--run-id=myawesomerun",
- "--save-freq=123456",
"--seed=7890",
"--train",
"--base-port=4004",
@@ -197,7 +190,6 @@ def test_yaml_args(mock_file):
assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
- assert opt.checkpoint_settings.save_freq == 123456
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2
diff --git a/ml-agents/mlagents/trainers/tests/test_ppo.py b/ml-agents/mlagents/trainers/tests/test_ppo.py
index 29247cc1a4..11af4549a7 100644
--- a/ml-agents/mlagents/trainers/tests/test_ppo.py
+++ b/ml-agents/mlagents/trainers/tests/test_ppo.py
@@ -351,7 +351,6 @@ def test_add_get_policy(ppo_optimizer, dummy_config):
# Make sure the summary steps were loaded properly
assert trainer.get_step == 2000
- assert trainer.next_summary_step > 2000
# Test incorrect class of policy
policy = mock.Mock()
diff --git a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
index 854bb99266..ee2847cfa2 100644
--- a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
+++ b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
@@ -43,7 +43,12 @@ def _process_trajectory(self, trajectory):
def create_rl_trainer():
mock_brainparams = create_mock_brain()
- trainer = FakeTrainer(mock_brainparams, TrainerSettings(max_steps=100), True, 0)
+ trainer = FakeTrainer(
+ mock_brainparams,
+ TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20),
+ True,
+ 0,
+ )
trainer.set_is_policy_updating(True)
return trainer
@@ -107,3 +112,45 @@ def test_advance(mocked_clear_update_buffer):
# Check that the buffer has been cleared
assert not trainer.should_still_train
assert mocked_clear_update_buffer.call_count > 0
+
+
+@mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
+@mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
+def test_summary_checkpoint(mock_write_summary, mock_save_model):
+ trainer = create_rl_trainer()
+ trajectory_queue = AgentManagerQueue("testbrain")
+ policy_queue = AgentManagerQueue("testbrain")
+ trainer.subscribe_trajectory_queue(trajectory_queue)
+ trainer.publish_policy_queue(policy_queue)
+ time_horizon = 10
+ summary_freq = trainer.trainer_settings.summary_freq
+ checkpoint_interval = trainer.trainer_settings.checkpoint_interval
+ trajectory = mb.make_fake_trajectory(
+ length=time_horizon,
+ max_step_complete=True,
+ vec_obs_size=1,
+ num_vis_obs=0,
+ action_space=[2],
+ )
+ # Feed in enough trajectories to cross several summary and checkpoint boundaries
+ num_trajectories = 5
+ for _ in range(0, num_trajectories):
+ trajectory_queue.put(trajectory)
+ trainer.advance()
+ # Check that there is stuff in the policy queue
+ policy_queue.get_nowait()
+
+ # Check that we have called write_summary the appropriate number of times
+ calls = [
+ mock.call(step)
+ for step in range(summary_freq, num_trajectories * time_horizon, summary_freq)
+ ]
+ mock_write_summary.assert_has_calls(calls, any_order=True)
+
+ calls = [
+ mock.call(trainer.brain_name)
+ for step in range(
+ checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
+ )
+ ]
+ mock_save_model.assert_has_calls(calls, any_order=True)
diff --git a/ml-agents/mlagents/trainers/tests/test_sac.py b/ml-agents/mlagents/trainers/tests/test_sac.py
index 47de1bd32e..8b8d21bdbe 100644
--- a/ml-agents/mlagents/trainers/tests/test_sac.py
+++ b/ml-agents/mlagents/trainers/tests/test_sac.py
@@ -138,7 +138,6 @@ def test_add_get_policy(sac_optimizer, dummy_config):
# Make sure the summary steps were loaded properly
assert trainer.get_step == 2000
- assert trainer.next_summary_step > 2000
# Test incorrect class of policy
policy = mock.Mock()
diff --git a/ml-agents/mlagents/trainers/tests/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/test_simple_rl.py
index 4a4996518c..907e27a189 100644
--- a/ml-agents/mlagents/trainers/tests/test_simple_rl.py
+++ b/ml-agents/mlagents/trainers/tests/test_simple_rl.py
@@ -115,7 +115,6 @@ def _check_environment_trains(
# Create controller and begin training.
with tempfile.TemporaryDirectory() as dir:
run_id = "id"
- save_freq = 99999
seed = 1337
StatsReporter.writers.clear() # Clear StatsReporters so we don't write to file
debug_writer = DebugWriter()
@@ -142,7 +141,6 @@ def _check_environment_trains(
training_seed=seed,
sampler_manager=SamplerManager(None),
resampling_interval=None,
- save_freq=save_freq,
)
# Begin training
diff --git a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
index 031241642f..cf2f872531 100644
--- a/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
+++ b/ml-agents/mlagents/trainers/tests/test_trainer_controller.py
@@ -15,7 +15,6 @@ def basic_trainer_controller():
trainer_factory=trainer_factory_mock,
output_path="test_model_path",
run_id="test_run_id",
- save_freq=100,
meta_curriculum=None,
train=True,
training_seed=99,
@@ -34,7 +33,6 @@ def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
trainer_factory=trainer_factory_mock,
output_path="",
run_id="1",
- save_freq=1,
meta_curriculum=None,
train=True,
training_seed=seed,
diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
index 7dfb4163a1..8872347722 100644
--- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py
+++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
@@ -4,6 +4,7 @@
import abc
import time
+from mlagents_envs.logging_util import get_logger
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trainer import Trainer
@@ -15,6 +16,8 @@
RewardSignalResults = Dict[str, RewardSignalResult]
+logger = get_logger(__name__)
+
class RLTrainer(Trainer): # pylint: disable=abstract-method
"""
@@ -34,6 +37,8 @@ def __init__(self, *args, **kwargs):
self._stats_reporter.add_property(
StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
)
+ self._next_save_step = 0
+ self._next_summary_step = 0
def end_episode(self) -> None:
"""
@@ -89,16 +94,20 @@ def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
- self.next_summary_step = self._get_next_summary_step()
+ self._next_summary_step = self._get_next_interval_step(self.summary_freq)
+ self._next_save_step = self._get_next_interval_step(
+ self.trainer_settings.checkpoint_interval
+ )
p = self.get_policy(name_behavior_id)
if p:
p.increment_step(n_steps)
- def _get_next_summary_step(self) -> int:
+ def _get_next_interval_step(self, interval: int) -> int:
"""
- Get the next step count that should result in a summary write.
+ Get the next step count that should result in an action (a summary write or a model checkpoint).
+ :param interval: The interval, in steps, between actions.
"""
- return self.step + (self.summary_freq - self.step % self.summary_freq)
+ return self.step + (interval - self.step % interval)
def _write_summary(self, step: int) -> None:
"""
@@ -114,6 +123,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.get_step + len(trajectory.steps))
+ self._maybe_save_model(self.get_step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
def _maybe_write_summary(self, step_after_process: int) -> None:
@@ -122,8 +132,24 @@ def _maybe_write_summary(self, step_after_process: int) -> None:
write the summary. This logic ensures summaries are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
- if step_after_process >= self.next_summary_step and self.get_step != 0:
- self._write_summary(self.next_summary_step)
+ if self._next_summary_step == 0: # Don't write out the first one
+ self._next_summary_step = self._get_next_interval_step(self.summary_freq)
+ if step_after_process >= self._next_summary_step and self.get_step != 0:
+ self._write_summary(self._next_summary_step)
+
+ def _maybe_save_model(self, step_after_process: int) -> None:
+ """
+ If processing the trajectory will make the step exceed the next model write,
+ save the model. This logic ensures models are written on the update step and not in between.
+ :param step_after_process: the step count after processing the next trajectory.
+ """
+ if self._next_save_step == 0: # Don't save the first one
+ self._next_save_step = self._get_next_interval_step(
+ self.trainer_settings.checkpoint_interval
+ )
+ if step_after_process >= self._next_save_step and self.get_step != 0:
+ logger.info(f"Checkpointing model for {self.brain_name}.")
+ self.save_model(self.brain_name)
def advance(self) -> None:
"""
diff --git a/ml-agents/mlagents/trainers/trainer/trainer.py b/ml-agents/mlagents/trainers/trainer/trainer.py
index fcad7ed5a7..1aec75e851 100644
--- a/ml-agents/mlagents/trainers/trainer/trainer.py
+++ b/ml-agents/mlagents/trainers/trainer/trainer.py
@@ -5,6 +5,7 @@
from collections import deque
from mlagents_envs.logging_util import get_logger
+from mlagents_envs.timers import timed
from mlagents.model_serialization import export_policy_model, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter
@@ -50,7 +51,6 @@ def __init__(
self.trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
self.step: int = 0
self.summary_freq = self.trainer_settings.summary_freq
- self.next_summary_step = self.summary_freq
@property
def stats_reporter(self):
@@ -110,6 +110,7 @@ def reward_buffer(self) -> Deque[float]:
"""
return self._reward_buffer
+ @timed
def save_model(self, name_behavior_id: str) -> None:
"""
Saves the model
diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py
index 390ae55872..3a8a74d15f 100644
--- a/ml-agents/mlagents/trainers/trainer_controller.py
+++ b/ml-agents/mlagents/trainers/trainer_controller.py
@@ -39,7 +39,6 @@ def __init__(
trainer_factory: TrainerFactory,
output_path: str,
run_id: str,
- save_freq: int,
meta_curriculum: Optional[MetaCurriculum],
train: bool,
training_seed: int,
@@ -50,7 +49,6 @@ def __init__(
:param output_path: Path to save the model.
:param summaries_dir: Folder to save training summaries.
:param run_id: The sub-directory name for model and summary statistics
- :param save_freq: Frequency at which to save model
:param meta_curriculum: MetaCurriculum object which stores information about all curricula.
:param train: Whether to train model, or only run inference.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
@@ -64,7 +62,6 @@ def __init__(
self.output_path = output_path
self.logger = get_logger(__name__)
self.run_id = run_id
- self.save_freq = save_freq
self.train_model = train
self.meta_curriculum = meta_curriculum
self.sampler_manager = sampler_manager
@@ -152,11 +149,6 @@ def _reset_env(self, env: EnvManager) -> None:
sampled_reset_param.update(new_meta_curriculum_config)
env.reset(config=sampled_reset_param)
- def _should_save_model(self, global_step: int) -> bool:
- return (
- global_step % self.save_freq == 0 and global_step != 0 and self.train_model
- )
-
def _not_done_training(self) -> bool:
return (
any(t.should_still_train for t in self.trainers.values())
@@ -229,13 +221,8 @@ def start_learning(self, env_manager: EnvManager) -> None:
for _ in range(n_steps):
global_step += 1
self.reset_env_if_ready(env_manager, global_step)
- if self._should_save_model(global_step):
- self._save_model()
# Stop advancing trainers
self.join_threads()
- # Final save Tensorflow model
- if global_step != 0 and self.train_model:
- self._save_model()
except (
KeyboardInterrupt,
UnityCommunicationException,
@@ -243,9 +230,9 @@ def start_learning(self, env_manager: EnvManager) -> None:
UnityCommunicatorStoppedException,
) as ex:
self.join_threads()
- if self.train_model:
- self._save_model_when_interrupted()
-
+ self.logger.info(
+ "Learning was interrupted. Please wait while the graph is generated."
+ )
if isinstance(ex, KeyboardInterrupt) or isinstance(
ex, UnityCommunicatorStoppedException
):
@@ -256,6 +243,7 @@ def start_learning(self, env_manager: EnvManager) -> None:
raise ex
finally:
if self.train_model:
+ self._save_model()
self._export_graph()
def end_trainer_episodes(
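A minimal sketch of the revised control flow in `start_learning` (simplified; the real method also joins trainer threads and re-raises certain exceptions): periodic checkpointing is now owned by the trainers themselves, and the final save and graph export always run in the `finally` block, including after an interruption.

```python
import logging

logger = logging.getLogger(__name__)


def run_with_final_save(step_fn, save_fn, export_fn, n_steps=3):
    # step_fn, save_fn and export_fn are hypothetical stand-ins for the environment
    # stepping loop, TrainerController._save_model and TrainerController._export_graph.
    try:
        for _ in range(n_steps):
            step_fn()
    except KeyboardInterrupt:
        logger.info("Learning was interrupted. Please wait while the graph is generated.")
    finally:
        save_fn()
        export_fn()
```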