@@ -138,6 +138,26 @@ def test_bcmodule_update(mock_env, trainer_config):
138
138
env .close ()
139
139
140
140
141
# Test with constant pretraining learning rate
@pytest.mark.parametrize(
    "trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
@mock.patch("mlagents.envs.environment.UnityEnvironment")
def test_bcmodule_constant_lr_update(mock_env, trainer_config):
    """Check that the BC module's learning rate is unchanged between updates.

    The trainer config's ``pretraining.steps`` is set to 0, which is the
    setting this test uses to request a constant pretraining learning rate.
    Two consecutive ``bc_module.update()`` calls are made and ``current_lr``
    is compared before and after the second one.

    Args:
        mock_env: Patched ``UnityEnvironment`` injected by ``mock.patch``.
        trainer_config: Trainer configuration supplied by ``parametrize``
            (one PPO variant and one SAC variant).
    """
    mock_brain = mb.create_mock_3dball_brain()
    trainer_config["pretraining"]["steps"] = 0
    env, policy = create_policy_with_bc_mock(
        mock_env, mock_brain, trainer_config, False, "test.demo"
    )
    # Close the environment even when an assertion fails, matching the
    # cleanup done by the neighbouring BC-module test.
    try:
        stats = policy.bc_module.update()
        # Only the reported values matter here, not their keys.
        for item in stats.values():
            assert isinstance(item, np.float32)
        old_learning_rate = policy.bc_module.current_lr

        # The second update's stats are not inspected; only the learning
        # rate it leaves behind is checked.
        policy.bc_module.update()
        assert old_learning_rate == policy.bc_module.current_lr
    finally:
        env.close()
159
+
160
+
141
161
# Test with RNN
142
162
@pytest .mark .parametrize (
143
163
"trainer_config" , [ppo_dummy_config (), sac_dummy_config ()], ids = ["ppo" , "sac" ]
0 commit comments