[bug-fix] Fix issue with SAC updating too much on resume #4038

Merged · 4 commits · Jun 1, 2020

2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -32,6 +32,8 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
 - Unity Player logs are now written out to the results directory. (#3877)
 - Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
 ### Bug Fixes
+- Fixed an issue where SAC would perform too many model updates when resuming from a
+  checkpoint, and too few when using `buffer_init_steps`. (#4038)
 #### com.unity.ml-agents (C#)
 #### ml-agents / ml-agents-envs / gym-unity (Python)

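Both failure modes in that changelog entry stem from SAC scheduling updates as a ratio of environment steps to completed updates. Below is a minimal sketch, not the trainer code itself (`count_updates_old` is an illustrative name), of the pre-fix condition shown in the diff further down, and of why a resumed run fires a burst of catch-up updates while a large `buffer_init_steps` suppresses updates for far too long.

    # Illustrative sketch of the pre-fix SAC update schedule, assuming the old
    # condition `self.step / self.update_steps > self.steps_per_update`.
    def count_updates_old(step, update_steps, steps_per_update):
        """Count how many updates the old while-loop would run in one call."""
        updates = 0
        while step / update_steps > steps_per_update:
            updates += 1
            update_steps += 1
        return updates, update_steps

    # Too many updates on resume: step is restored from the checkpoint (e.g. 200)
    # but update_steps starts back at 1, so the next trajectory triggers a burst.
    print(count_updates_old(step=200, update_steps=1, steps_per_update=20))      # (9, 10)

    # Too few updates with buffer_init_steps: the old __init__ seeded update_steps
    # with buffer_init_steps (e.g. 1000), so nothing updates until step > 20000.
    print(count_updates_old(step=5000, update_steps=1000, steps_per_update=20))  # (0, 1000)
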
20 changes: 13 additions & 7 deletions ml-agents/mlagents/trainers/sac/trainer.py
@@ -65,9 +65,9 @@ def __init__(
         )
         self.step = 0

-        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
-        self.update_steps = max(1, self.hyperparameters.buffer_init_steps)
-        self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps)
+        # Don't divide by zero
+        self.update_steps = 1
+        self.reward_signal_update_steps = 1

         self.steps_per_update = self.hyperparameters.steps_per_update
         self.reward_signal_steps_per_update = (
@@ -229,7 +229,9 @@ def _update_sac_policy(self) -> bool:
         )

         batch_update_stats: Dict[str, list] = defaultdict(list)
-        while self.step / self.update_steps > self.steps_per_update:
+        while (
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.update_steps > self.steps_per_update:
             logger.debug("Updating SAC policy at step {}".format(self.step))
             buffer = self.update_buffer
             if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
@@ -282,9 +284,8 @@ def _update_reward_signals(self) -> None:
         )
         batch_update_stats: Dict[str, list] = defaultdict(list)
         while (
-            self.step / self.reward_signal_update_steps
-            > self.reward_signal_steps_per_update
-        ):
+            self.step - self.hyperparameters.buffer_init_steps
+        ) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
             # Get minibatches for reward signal update if needed
             reward_signal_minibatches = {}
             for name, signal in self.optimizer.reward_signals.items():
@@ -327,6 +328,11 @@ def add_policy(
             self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
         # Needed to resume loads properly
         self.step = policy.get_current_step()
+        # Assume steps were updated at the correct ratio before
+        self.update_steps = int(max(1, self.step / self.steps_per_update))
+        self.reward_signal_update_steps = int(
+            max(1, self.step / self.reward_signal_steps_per_update)
+        )
         self.next_summary_step = self._get_next_summary_step()

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
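For reference, a simplified, self-contained sketch of the corrected schedule, assuming the behaviour shown in the diff above: `update_steps` counts completed updates, `buffer_init_steps` is subtracted from the step count, and `add_policy` re-seeds the counter from the restored step. The function names below are illustrative and not part of the trainer API.

    def seed_update_steps_on_resume(step, steps_per_update):
        # Mirrors the add_policy change: assume past steps were already matched
        # by updates at the configured ratio.
        return int(max(1, step / steps_per_update))

    def run_updates(step, update_steps, steps_per_update, buffer_init_steps=0):
        # Mirrors the new while-loop condition in _update_sac_policy; each
        # iteration stands in for one optimizer.update(...) call.
        updates = 0
        while (step - buffer_init_steps) / update_steps > steps_per_update:
            updates += 1
            update_steps += 1
        return updates, update_steps

    # With buffer_init_steps=1000 and steps_per_update=20, only the 100 steps
    # collected after the warm-up count toward the ratio (roughly one update
    # per 20 post-warm-up steps, instead of none).
    print(run_updates(step=1100, update_steps=1, steps_per_update=20,
                      buffer_init_steps=1000))                          # (4, 5)

    # Resuming at step 200 with steps_per_update=20 seeds the counter to 10,
    # so a short new trajectory triggers roughly one update, not ten.
    print(run_updates(step=205,
                      update_steps=seed_update_steps_on_resume(200, 20),
                      steps_per_update=20))                             # (1, 11)
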
16 changes: 16 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_sac.py
@@ -151,6 +151,7 @@ def test_advance(dummy_config):
         discrete_action=False, visual_inputs=0, vec_obs_size=6
     )
     dummy_config.hyperparameters.steps_per_update = 20
+    dummy_config.hyperparameters.reward_signal_steps_per_update = 20
     dummy_config.hyperparameters.buffer_init_steps = 0
     trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
     policy = trainer.create_policy(brain_params.brain_name, brain_params)
@@ -220,6 +221,21 @@ def test_advance(dummy_config):
     with pytest.raises(AgentManagerQueue.Empty):
         policy_queue.get_nowait()
+
+    # Call add_policy and check that we update the correct number of times.
+    # This is to emulate a load from checkpoint.
+    policy = trainer.create_policy(brain_params.brain_name, brain_params)
+    policy.get_current_step = lambda: 200
+    trainer.add_policy(brain_params.brain_name, policy)
+    trainer.optimizer.update = mock.Mock()
+    trainer.optimizer.update_reward_signals = mock.Mock()
+    trainer.optimizer.update_reward_signals.return_value = {}
+    trainer.optimizer.update.return_value = {}
+    trajectory_queue.put(trajectory)
+    trainer.advance()
+    # Make sure we did exactly 1 update
+    assert trainer.optimizer.update.call_count == 1
+    assert trainer.optimizer.update_reward_signals.call_count == 1


 if __name__ == "__main__":
     pytest.main()
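Rough arithmetic behind the last two assertions, as a sketch: with `steps_per_update = 20` and the step restored to 200, `add_policy` seeds `update_steps` to 10, so exactly one update (and one reward-signal update) fires as long as the queued trajectory advances `self.step` to somewhere in (200, 220]. The value 212 below is illustrative; the actual trajectory length is defined earlier in the test file and not shown in this diff.

    # Hypothetical walk-through of the "exactly 1 update" assertion.
    step, steps_per_update = 212, 20                     # 212 is an assumed post-trajectory step
    update_steps = int(max(1, 200 / steps_per_update))   # seeded to 10 by add_policy on resume
    updates = 0
    while step / update_steps > steps_per_update:        # buffer_init_steps == 0 in this test
        updates += 1
        update_steps += 1
    assert updates == 1                                  # matches optimizer.update.call_count == 1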