
Commit 437d04e

Authored by Ervin T
Don't run value during policy evaluate, optimized soft update function (#4501)
* Don't run value during inference
* Execute critic with LSTM
* Address comments
* Unformat
* Optimized soft update
* Move soft update to model utils
* Add test for soft update
1 parent 2719971 commit 437d04e

7 files changed (+75, -46 lines)

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 19 additions & 18 deletions
@@ -124,15 +124,26 @@ def sample_actions(
         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
         all_log_probs: bool = False,
-    ) -> Tuple[
-        torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
-    ]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
+        :param vec_obs: List of vector observations.
+        :param vis_obs: List of visual observations.
+        :param masks: Loss masks for RNN, else None.
+        :param memories: Input memories when using RNN, else None.
+        :param seq_len: Sequence length when using RNN.
         :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
+        :return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and
+            output memories, all as Torch Tensors.
         """
-        dists, value_heads, memories = self.actor_critic.get_dist_and_value(
-            vec_obs, vis_obs, masks, memories, seq_len
-        )
+        if memories is None:
+            dists, memories = self.actor_critic.get_dists(
+                vec_obs, vis_obs, masks, memories, seq_len
+            )
+        else:
+            # If we're using LSTM. we need to execute the values to get the critic memories
+            dists, _, memories = self.actor_critic.get_dist_and_value(
+                vec_obs, vis_obs, masks, memories, seq_len
+            )
         action_list = self.actor_critic.sample_action(dists)
         log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
             action_list, dists
@@ -143,13 +154,7 @@ def sample_actions(
         else:
             actions = actions[:, 0, :]
 
-        return (
-            actions,
-            all_logs if all_log_probs else log_probs,
-            entropies,
-            value_heads,
-            memories,
-        )
+        return (actions, all_logs if all_log_probs else log_probs, entropies, memories)
 
     def evaluate_actions(
         self,
@@ -189,18 +194,14 @@ def evaluate(
 
         run_out = {}
         with torch.no_grad():
-            action, log_probs, entropy, value_heads, memories = self.sample_actions(
+            action, log_probs, entropy, memories = self.sample_actions(
                 vec_obs, vis_obs, masks=masks, memories=memories
             )
         run_out["action"] = ModelUtils.to_numpy(action)
         run_out["pre_action"] = ModelUtils.to_numpy(action)
         # Todo - make pre_action difference
         run_out["log_probs"] = ModelUtils.to_numpy(log_probs)
         run_out["entropy"] = ModelUtils.to_numpy(entropy)
-        run_out["value_heads"] = {
-            name: ModelUtils.to_numpy(t) for name, t in value_heads.items()
-        }
-        run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
         run_out["learning_rate"] = 0.0
         if self.use_recurrent:
             run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 11 additions & 16 deletions
@@ -113,7 +113,9 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
             self.policy.behavior_spec.observation_shapes,
             policy_network_settings,
         )
-        self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
+        ModelUtils.soft_update(
+            self.policy.actor_critic.critic, self.target_network, 1.0
+        )
 
         self._log_ent_coef = torch.nn.Parameter(
             torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
@@ -201,12 +203,6 @@ def sac_q_loss(
         q2_loss = torch.mean(torch.stack(q2_losses))
         return q1_loss, q2_loss
 
-    def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
-        for source_param, target_param in zip(source.parameters(), target.parameters()):
-            target_param.data.copy_(
-                target_param.data * (1.0 - tau) + source_param.data * tau
-            )
-
     def sac_value_loss(
         self,
         log_probs: torch.Tensor,
@@ -441,20 +437,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         self.target_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        (
-            sampled_actions,
-            log_probs,
-            entropies,
-            sampled_values,
-            _,
-        ) = self.policy.sample_actions(
+        (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,
             memories=memories,
             seq_len=self.policy.sequence_length,
             all_log_probs=not self.policy.use_continuous_act,
         )
+        value_estimates, _ = self.policy.actor_critic.critic_pass(
+            vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
+        )
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)
             q1p_out, q2p_out = self.value_network(
@@ -504,7 +497,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             q1_stream, q2_stream, target_values, dones, rewards, masks
         )
         value_loss = self.sac_value_loss(
-            log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
+            log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
         )
         policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
         entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
@@ -528,7 +521,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         self.entropy_optimizer.step()
 
         # Update target network
-        self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
+        ModelUtils.soft_update(
+            self.policy.actor_critic.critic, self.target_network, self.tau
+        )
         update_stats = {
             "Losses/Policy Loss": policy_loss.item(),
             "Losses/Value Loss": value_loss.item(),

ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py

Lines changed: 2 additions & 2 deletions
@@ -79,10 +79,10 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
     ).unsqueeze(0)
 
     with torch.no_grad():
-        _, log_probs1, _, _, _ = policy1.sample_actions(
+        _, log_probs1, _, _ = policy1.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
         )
-        _, log_probs2, _, _, _ = policy2.sample_actions(
+        _, log_probs2, _, _ = policy2.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
         )

ml-agents/mlagents/trainers/tests/torch/test_policy.py

Lines changed: 1 addition & 9 deletions
@@ -121,13 +121,7 @@ def test_sample_actions(rnn, visual, discrete):
     if len(memories) > 0:
         memories = torch.stack(memories).unsqueeze(0)
 
-    (
-        sampled_actions,
-        log_probs,
-        entropies,
-        sampled_values,
-        memories,
-    ) = policy.sample_actions(
+    (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
         vec_obs,
         vis_obs,
         masks=act_masks,
@@ -143,8 +137,6 @@ def test_sample_actions(rnn, visual, discrete):
     else:
         assert log_probs.shape == (64, policy.behavior_spec.action_shape)
     assert entropies.shape == (64, policy.behavior_spec.action_size)
-    for val in sampled_values.values():
-        assert val.shape == (64,)
 
     if rnn:
         assert memories.shape == (1, 1, policy.m_size)

ml-agents/mlagents/trainers/tests/torch/test_utils.py

Lines changed: 17 additions & 0 deletions
@@ -208,3 +208,20 @@ def test_masked_mean():
     masks = torch.tensor([False, False, True, True, True])
     mean = ModelUtils.masked_mean(test_input, masks=masks)
     assert mean == 4.0
+
+
+def test_soft_update():
+    class TestModule(torch.nn.Module):
+        def __init__(self, vals):
+            super().__init__()
+            self.parameter = torch.nn.Parameter(torch.ones(5, 5, 5) * vals)
+
+    tm1 = TestModule(0)
+    tm2 = TestModule(1)
+    tm3 = TestModule(2)
+
+    ModelUtils.soft_update(tm1, tm3, tau=0.5)
+    assert torch.equal(tm3.parameter, torch.ones(5, 5, 5))
+
+    ModelUtils.soft_update(tm1, tm2, tau=1.0)
+    assert torch.equal(tm2.parameter, tm1.parameter)
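The expected tensors here follow directly from the update rule target = tau * source + (1 - tau) * target: with tau = 0.5, tm3's parameters move from 2 to 0.5 * 0 + 0.5 * 2 = 1, hence the comparison against torch.ones(5, 5, 5); with tau = 1.0, tm2's parameters are overwritten entirely by tm1's, so the two parameters compare equal afterwards.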

ml-agents/mlagents/trainers/torch/components/bc/module.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def _update_batch(
         else:
             vis_obs = []
 
-        selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
+        selected_actions, all_log_probs, _, _ = self.policy.sample_actions(
             vec_obs,
             vis_obs,
             masks=act_masks,

ml-agents/mlagents/trainers/torch/utils.py

Lines changed: 24 additions & 0 deletions
@@ -301,3 +301,27 @@ def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
         return (tensor.T * masks).sum() / torch.clamp(
             (torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
         )
+
+    @staticmethod
+    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
+        """
+        Performs an in-place polyak update of the target module based on the source,
+        by a ratio of tau. Note that source and target modules must have the same
+        parameters, where:
+            target = tau * source + (1-tau) * target
+        :param source: Source module whose parameters will be used.
+        :param target: Target module whose parameters will be updated.
+        :param tau: Percentage of source parameters to use in average. Setting tau to
+            1 will copy the source parameters to the target.
+        """
+        with torch.no_grad():
+            for source_param, target_param in zip(
+                source.parameters(), target.parameters()
+            ):
+                target_param.data.mul_(1.0 - tau)
+                torch.add(
+                    target_param.data,
+                    source_param.data,
+                    alpha=tau,
+                    out=target_param.data,
+                )
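Compared with the soft_update method removed from optimizer_torch.py above, which evaluated target_param.data * (1.0 - tau) + source_param.data * tau (allocating temporary tensors for every parameter on every call) before copying the result back, this version scales each target parameter in place and then accumulates the scaled source into the same storage. A side-by-side sketch of the two variants; the function names are illustrative only, and both assume source and target have matching parameters:

import torch
from torch import nn


def soft_update_old(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Previous approach: the right-hand side builds new tensors, then copy_ writes them back.
    for source_param, target_param in zip(source.parameters(), target.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + source_param.data * tau
        )


def soft_update_new(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Optimized approach: in-place scale, then add the source (scaled by tau) into the same storage.
    with torch.no_grad():
        for source_param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.mul_(1.0 - tau)
            torch.add(
                target_param.data, source_param.data, alpha=tau, out=target_param.data
            )

Both produce the same blended parameters; the second simply avoids the per-parameter temporaries.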
