diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
index 02ea27741a..7fc27cb44f 100644
--- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
+++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -109,7 +109,7 @@ def ppo_policy_loss(
             torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
         )
         policy_loss = -1 * ModelUtils.masked_mean(
-            torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
+            torch.min(p_opt_a, p_opt_b), loss_masks
         )
         return policy_loss
 
@@ -177,7 +177,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         loss = (
             policy_loss
             + 0.5 * value_loss
-            - decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks)
+            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
         )
 
         # Set optimizer learning rate
diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py
index ff0eb758de..9ca71be3bf 100644
--- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py
+++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py
@@ -413,7 +413,9 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             memories = None
             next_memories = None
         # Q network memories are 0'ed out, since we don't have them during inference.
-        q_memories = torch.zeros_like(next_memories)
+        q_memories = (
+            torch.zeros_like(next_memories) if next_memories is not None else None
+        )
 
         vis_obs: List[torch.Tensor] = []
         next_vis_obs: List[torch.Tensor] = []
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
index 0275581d08..ae52456f3e 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py
@@ -214,3 +214,9 @@ def test_masked_mean():
     masks = torch.tensor([False, False, False, False, False])
     mean = ModelUtils.masked_mean(test_input, masks=masks)
     assert mean == 0.0
+
+    # Make sure it works with 2d arrays of shape (mask_length, N)
+    test_input = torch.tensor([1, 2, 3, 4, 5]).repeat(2, 1).T
+    masks = torch.tensor([False, False, True, True, True])
+    mean = ModelUtils.masked_mean(test_input, masks=masks)
+    assert mean == 4.0
diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py
index 0e855ea79b..570fa7b7bf 100644
--- a/ml-agents/mlagents/trainers/torch/utils.py
+++ b/ml-agents/mlagents/trainers/torch/utils.py
@@ -293,4 +293,6 @@ def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
         :param tensor: Tensor which needs mean computation.
         :param masks: Boolean tensor of masks with same dimension as tensor.
         """
-        return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)
+        return (tensor.T * masks).sum() / torch.clamp(
+            (torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
+        )
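For reviewers, here is a minimal standalone sketch of the broadcasting behavior the new `masked_mean` relies on, which is why the `.flatten()` calls in the PPO losses can be dropped. The function body is copied from the diff; the example values are illustrative and mirror the new test case. Transposing puts the mask dimension last, so a 1-D `masks` of shape `(mask_length,)` broadcasts against both 1-D and `(mask_length, N)` inputs, and the denominator counts every unmasked element rather than only the mask entries:

```python
import torch


def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # tensor.T has shape (N, mask_length) for a 2-D input, so the 1-D
    # masks broadcast along the trailing dimension; for a 1-D input,
    # .T is a no-op and this reduces to the previous behavior.
    # The denominator is num_unmasked * N (clamped to 1 to avoid
    # division by zero when every entry is masked).
    return (tensor.T * masks).sum() / torch.clamp(
        (torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
    )


# 1-D input: same result as the old implementation.
masks = torch.tensor([False, False, True, True, True])
x1 = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
assert masked_mean(x1, masks) == 4.0  # mean of [3, 4, 5]

# 2-D input of shape (mask_length, N): the mask selects whole rows.
x2 = torch.tensor([1, 2, 3, 4, 5]).repeat(2, 1).T  # shape (5, 2)
assert masked_mean(x2, masks) == 4.0  # (3 + 4 + 5) * 2 / (3 * 2)
```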