Skip to content

Commit 9af6127

Browse files
Ervin T and Chris Elion
authored and committed
Fix bug where constant LR in pretraining will throw TF error (#2977)
1 parent 226905f commit 9af6127

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

ml-agents/mlagents/trainers/components/bc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def create_loss(self, learning_rate: float, anneal_steps: int) -> None:
7474
power=1.0,
7575
)
7676
else:
77-
self.annealed_learning_rate = learning_rate
77+
self.annealed_learning_rate = tf.Variable(learning_rate)
7878

7979
optimizer = tf.train.AdamOptimizer(learning_rate=self.annealed_learning_rate)
8080
self.update_batch = optimizer.minimize(self.loss)

ml-agents/mlagents/trainers/tests/test_bcmodule.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,26 @@ def test_bcmodule_update(mock_env, trainer_config):
138138
env.close()
139139

140140

141+
# Test with constant pretraining learning rate
142+
@pytest.mark.parametrize(
143+
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
144+
)
145+
@mock.patch("mlagents.envs.environment.UnityEnvironment")
146+
def test_bcmodule_constant_lr_update(mock_env, trainer_config):
147+
mock_brain = mb.create_mock_3dball_brain()
148+
trainer_config["pretraining"]["steps"] = 0
149+
env, policy = create_policy_with_bc_mock(
150+
mock_env, mock_brain, trainer_config, False, "test.demo"
151+
)
152+
stats = policy.bc_module.update()
153+
for _, item in stats.items():
154+
assert isinstance(item, np.float32)
155+
old_learning_rate = policy.bc_module.current_lr
156+
157+
stats = policy.bc_module.update()
158+
assert old_learning_rate == policy.bc_module.current_lr
159+
160+
141161
# Test with RNN
142162
@pytest.mark.parametrize(
143163
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]

0 commit comments

Comments (0)