diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py
index 03a89f8046..33bbec977b 100644
--- a/ml-agents/mlagents/trainers/policy/policy.py
+++ b/ml-agents/mlagents/trainers/policy/policy.py
@@ -152,7 +152,9 @@ def get_current_step(self):
         pass
 
     @abstractmethod
-    def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
+    def checkpoint(
+        self, checkpoint_path: str, settings: Optional[SerializationSettings]
+    ) -> None:
         pass
 
     @abstractmethod
diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py
index b7c460c35f..413e6ac0b9 100644
--- a/ml-agents/mlagents/trainers/policy/tf_policy.py
+++ b/ml-agents/mlagents/trainers/policy/tf_policy.py
@@ -417,12 +417,15 @@ def get_update_vars(self):
         """
         return list(self.update_dict.keys())
 
-    def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
+    def checkpoint(
+        self, checkpoint_path: str, settings: Optional[SerializationSettings]
+    ) -> None:
         """
         Checkpoints the policy on disk.
 
         :param checkpoint_path: filepath to write the checkpoint
-        :param settings: SerializationSettings for exporting the model.
+        :param settings: SerializationSettings for exporting the model. If None,
+            the model will not be saved.
         """
         # Save the TF checkpoint and graph definition
         with self.graph.as_default():
@@ -431,8 +434,9 @@ def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> N
             tf.train.write_graph(
                 self.graph, self.model_path, "raw_graph_def.pb", as_text=False
             )
-        # also save the policy so we have optimized model files for each checkpoint
-        self.save(checkpoint_path, settings)
+        if settings is not None:
+            # also save the policy so we have optimized model files for each checkpoint
+            self.save(checkpoint_path, settings)
 
     def save(self, output_filepath: str, settings: SerializationSettings) -> None:
         """
diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py
index cdaab0de40..527d544b2c 100644
--- a/ml-agents/mlagents/trainers/sac/trainer.py
+++ b/ml-agents/mlagents/trainers/sac/trainer.py
@@ -76,12 +76,12 @@ def __init__(
         self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
 
-    def _checkpoint(self) -> NNCheckpoint:
+    def _checkpoint(self, save_model: bool) -> NNCheckpoint:
         """
         Writes a checkpoint model to memory
         Overrides the default to save the replay buffer.
         """
-        ckpt = super()._checkpoint()
+        ckpt = super()._checkpoint(save_model)
         if self.checkpoint_replay_buffer:
             self.save_replay_buffer()
         return ckpt
 
diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
index 2ab44bcf1f..15a55bec50 100644
--- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py
+++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
@@ -96,7 +96,7 @@ def _policy_mean_reward(self) -> Optional[float]:
         return sum(rewards) / len(rewards)
 
     @timed
-    def _checkpoint(self) -> NNCheckpoint:
+    def _checkpoint(self, save_model: bool) -> NNCheckpoint:
         """
         Checkpoints the policy associated with this trainer.
         """
@@ -107,8 +107,11 @@ def _checkpoint(self) -> NNCheckpoint:
         )
         policy = list(self.policies.values())[0]
         model_path = policy.model_path
-        settings = SerializationSettings(model_path, self.brain_name)
         checkpoint_path = os.path.join(model_path, f"{self.brain_name}-{self.step}")
+        # Don't pass SerializationSettings if we're not going to save the model.
+        settings = (
+            SerializationSettings(model_path, self.brain_name) if save_model else None
+        )
         policy.checkpoint(checkpoint_path, settings)
         new_checkpoint = NNCheckpoint(
             int(self.step),
@@ -132,7 +135,7 @@ def save_model(self) -> None:
         )
         policy = list(self.policies.values())[0]
         settings = SerializationSettings(policy.model_path, self.brain_name)
-        model_checkpoint = self._checkpoint()
+        model_checkpoint = self._checkpoint(save_model=False)
         final_checkpoint = attr.evolve(
             model_checkpoint, file_path=f"{policy.model_path}.nn"
         )
@@ -207,7 +210,7 @@ def _maybe_save_model(self, step_after_process: int) -> None:
             self.trainer_settings.checkpoint_interval
         )
         if step_after_process >= self._next_save_step and self.get_step != 0:
-            self._checkpoint()
+            self._checkpoint(save_model=True)
 
     def advance(self) -> None:
         """
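
Not part of the diff: a minimal sketch of the contract this change introduces. `checkpoint()` always writes the raw TF checkpoint, but only exports the optimized model when a `SerializationSettings` is provided, which lets `save_model()` call `_checkpoint(save_model=False)` and export the final model itself instead of writing it twice. The `SerializationSettings` NamedTuple and `StubPolicy` class below are hypothetical stand-ins, not the real ml-agents implementations.

from typing import NamedTuple, Optional


class SerializationSettings(NamedTuple):
    # Hypothetical stand-in for the real SerializationSettings.
    model_path: str
    brain_name: str


class StubPolicy:
    # Hypothetical stand-in illustrating the new checkpoint() contract.
    def checkpoint(
        self, checkpoint_path: str, settings: Optional[SerializationSettings]
    ) -> None:
        # The TF checkpoint and graph definition are always written ...
        print(f"wrote {checkpoint_path}.ckpt")
        # ... but the optimized model is only exported when settings are given.
        if settings is not None:
            print(f"exported {settings.brain_name} model under {settings.model_path}")


policy = StubPolicy()

# Periodic checkpoint (_maybe_save_model -> _checkpoint(save_model=True)):
# export an optimized model alongside the checkpoint.
policy.checkpoint(
    "results/run/Brain-1000", SerializationSettings("results/run", "Brain")
)

# Final save (save_model -> _checkpoint(save_model=False)): skip the export here,
# since save_model() exports the final model itself right after.
policy.checkpoint("results/run/Brain-2000", None)

Threading the `save_model` flag through `_checkpoint()`, rather than skipping the checkpoint call entirely, keeps the SAC override on the same code path, so the replay buffer is still saved for both periodic and final checkpoints.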