Skip to content

Commit d37960c

Browse files
author
Ervin T
authored
[add-fire] Fix masked mean for 2d tensors (#4364)
1 parent 291091a commit d37960c

File tree

4 files changed

+14
-4
lines changed

4 files changed

+14
-4
lines changed

ml-agents/mlagents/trainers/ppo/optimizer_torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def ppo_policy_loss(
109109
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
110110
)
111111
policy_loss = -1 * ModelUtils.masked_mean(
112-
torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
112+
torch.min(p_opt_a, p_opt_b), loss_masks
113113
)
114114
return policy_loss
115115

@@ -177,7 +177,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
177177
loss = (
178178
policy_loss
179179
+ 0.5 * value_loss
180-
- decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks)
180+
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
181181
)
182182

183183
# Set optimizer learning rate

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
413413
memories = None
414414
next_memories = None
415415
# Q network memories are 0'ed out, since we don't have them during inference.
416-
q_memories = torch.zeros_like(next_memories)
416+
q_memories = (
417+
torch.zeros_like(next_memories) if next_memories is not None else None
418+
)
417419

418420
vis_obs: List[torch.Tensor] = []
419421
next_vis_obs: List[torch.Tensor] = []

ml-agents/mlagents/trainers/tests/torch/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,9 @@ def test_masked_mean():
214214
masks = torch.tensor([False, False, False, False, False])
215215
mean = ModelUtils.masked_mean(test_input, masks=masks)
216216
assert mean == 0.0
217+
218+
# Make sure it works with 2d arrays of shape (mask_length, N)
219+
test_input = torch.tensor([1, 2, 3, 4, 5]).repeat(2, 1).T
220+
masks = torch.tensor([False, False, True, True, True])
221+
mean = ModelUtils.masked_mean(test_input, masks=masks)
222+
assert mean == 4.0

ml-agents/mlagents/trainers/torch/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,6 @@ def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
293293
:param tensor: Tensor which needs mean computation.
294294
:param masks: Boolean tensor of masks with same dimension as tensor.
295295
"""
296-
return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)
296+
return (tensor.T * masks).sum() / torch.clamp(
297+
(torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
298+
)

0 commit comments

Comments
 (0)