diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 84277ed2cc..f7012b3e65 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -32,6 +32,8 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
 - Unity Player logs are now written out to the results directory. (#3877)
 - Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
 ### Bug Fixes
+- Fixed an issue where SAC would perform too many model updates when resuming from a
+  checkpoint, and too few when using `buffer_init_steps`. (#4038)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py
index 9b321ff500..397747a32a 100644
--- a/ml-agents/mlagents/trainers/sac/trainer.py
+++ b/ml-agents/mlagents/trainers/sac/trainer.py
@@ -65,9 +65,9 @@ def __init__(
         )
 
         self.step = 0
-        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
-        self.update_steps = max(1, self.hyperparameters.buffer_init_steps)
-        self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps)
+        # Don't divide by zero
+        self.update_steps = 1
+        self.reward_signal_update_steps = 1
 
         self.steps_per_update = self.hyperparameters.steps_per_update
         self.reward_signal_steps_per_update = (
@@ -229,7 +229,9 @@ def _update_sac_policy(self) -> bool:
         )
 
         batch_update_stats: Dict[str, list] = defaultdict(list)
-        while self.step / self.update_steps > self.steps_per_update:
+        while (
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.update_steps > self.steps_per_update:
             logger.debug("Updating SAC policy at step {}".format(self.step))
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
@@ -282,9 +284,8 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step / self.reward_signal_update_steps
-            > self.reward_signal_steps_per_update
-        ):
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name, signal in self.optimizer.reward_signals.items():
@@ -327,6 +328,11 @@ def add_policy(
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
        self.step = policy.get_current_step()
+        # Assume steps were updated at the correct ratio before
+        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.reward_signal_update_steps = int(
+            max(1, self.step / self.reward_signal_steps_per_update)
+        )
         self.next_summary_step = self._get_next_summary_step()
 
     def get_policy(self, name_behavior_id: str) -> TFPolicy:
diff --git a/ml-agents/mlagents/trainers/tests/test_sac.py b/ml-agents/mlagents/trainers/tests/test_sac.py
index 62fa154cc3..47de1bd32e 100644
--- a/ml-agents/mlagents/trainers/tests/test_sac.py
+++ b/ml-agents/mlagents/trainers/tests/test_sac.py
@@ -151,6 +151,7 @@ def test_advance(dummy_config):
         discrete_action=False, visual_inputs=0, vec_obs_size=6
     )
     dummy_config.hyperparameters.steps_per_update = 20
+    dummy_config.hyperparameters.reward_signal_steps_per_update = 20
     dummy_config.hyperparameters.buffer_init_steps = 0
     trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
     policy = trainer.create_policy(brain_params.brain_name, brain_params)
@@ -220,6 +221,21 @@ def test_advance(dummy_config):
     with pytest.raises(AgentManagerQueue.Empty):
         policy_queue.get_nowait()
 
+    # Call add_policy and check that we update the correct number of times.
+    # This is to emulate a load from checkpoint.
+    policy = trainer.create_policy(brain_params.brain_name, brain_params)
+    policy.get_current_step = lambda: 200
+    trainer.add_policy(brain_params.brain_name, policy)
+    trainer.optimizer.update = mock.Mock()
+    trainer.optimizer.update_reward_signals = mock.Mock()
+    trainer.optimizer.update_reward_signals.return_value = {}
+    trainer.optimizer.update.return_value = {}
+    trajectory_queue.put(trajectory)
+    trainer.advance()
+    # Make sure we did exactly 1 update
+    assert trainer.optimizer.update.call_count == 1
+    assert trainer.optimizer.update_reward_signals.call_count == 1
+
 
 if __name__ == "__main__":
     pytest.main()
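Note (not part of the diff above): a minimal, standalone sketch of the update-cadence arithmetic this patch is aiming for. The class SimpleSACUpdateCounter and all of its names are hypothetical and only model the counting logic, not the real SACTrainer: steps collected during buffer_init_steps are excluded from the steps_per_update ratio, and on resume update_steps is re-derived from the loaded step count so no backlog of updates is replayed.

# Hypothetical model of the counting logic; mirrors the new while-condition and add_policy() math.
class SimpleSACUpdateCounter:
    def __init__(self, steps_per_update: float, buffer_init_steps: int) -> None:
        self.steps_per_update = steps_per_update
        self.buffer_init_steps = buffer_init_steps
        self.step = 0
        self.update_steps = 1  # don't divide by zero, as in the new __init__

    def resume(self, loaded_step: int) -> None:
        # Emulate add_policy() after a checkpoint load: assume updates already
        # happened at the correct ratio, so no burst of catch-up updates occurs.
        self.step = loaded_step
        self.update_steps = int(max(1, loaded_step / self.steps_per_update))

    def advance(self, new_steps: int) -> int:
        # Count how many policy updates the corrected while-condition would trigger.
        self.step += new_steps
        updates = 0
        while (
            self.step - self.buffer_init_steps
        ) / self.update_steps > self.steps_per_update:
            updates += 1
            self.update_steps += 1
        return updates


counter = SimpleSACUpdateCounter(steps_per_update=20, buffer_init_steps=0)
counter.resume(loaded_step=200)
print(counter.advance(20))  # 1 -- a single update for 20 new steps, as the new test expects

Under the old logic, resuming at step 200 with update_steps still at 1 would have driven roughly step / steps_per_update extra updates in one go; the recomputed update_steps in add_policy() removes that burst, and subtracting buffer_init_steps in the loop condition restores the missing updates for the pre-fill phase.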