From 2281ba0a7d142d9bbf6633d883c8251bd36dab1c Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 5 Aug 2020 15:28:04 -0700 Subject: [PATCH] Fix non-LSTM separateactorcritic --- ml-agents/mlagents/trainers/tests/torch/test_networks.py | 4 +++- ml-agents/mlagents/trainers/torch/networks.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index 6d6c62d2ad..b6b93698f1 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -188,7 +188,9 @@ def test_actor_critic(ac_type, lstm): ) else: sample_obs = torch.ones((1, obs_size)) - memories = None + memories = torch.tensor([]) + # memories isn't always set to None, the network should be able to + # deal with that. # Test critic pass value_out = actor.critic_pass([sample_obs], [], memories=memories) for stream in stream_names: diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index e6bc0f3385..15f0d92e2d 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -428,7 +428,7 @@ def critic_pass( vis_inputs: List[torch.Tensor], memories: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: - if memories is not None: + if self.use_lstm: # Use only the back half of memories for critic _, critic_mem = torch.split(memories, self.half_mem_size, -1) else: @@ -446,7 +446,7 @@ def get_dist_and_value( memories: Optional[torch.Tensor] = None, sequence_length: int = 1, ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: - if memories is not None: + if self.use_lstm: # Use only the back half of memories for critic and actor actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1) else: