diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py
index 63f8561bc3..cf19d2b438 100644
--- a/ml-agents/mlagents/trainers/agent_processor.py
+++ b/ml-agents/mlagents/trainers/agent_processor.py
@@ -137,7 +137,7 @@ def _process_step(
             action_pre = None
         action_probs = stored_take_action_outputs["log_probs"][idx]
         action_mask = stored_decision_step.action_mask
-        prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
+        prev_action = self.policy.retrieve_previous_action_single(global_id)
         experience = AgentExperience(
             obs=obs,
             reward=step.reward,
diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py
index b23a819b04..772abe884c 100644
--- a/ml-agents/mlagents/trainers/policy/tf_policy.py
+++ b/ml-agents/mlagents/trainers/policy/tf_policy.py
@@ -353,6 +353,16 @@ def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
                 action_matrix[index, :] = self.previous_action_dict[agent_id]
         return action_matrix
 
+    def retrieve_previous_action_single(self, agent_id: str) -> np.ndarray:
+        """
+        A more efficient version of retrieve_previous_action() for a single
+        agent at a time.
+        """
+        prev_action = self.previous_action_dict.get(agent_id)
+        if prev_action is not None:
+            return prev_action
+        return np.zeros(self.num_branches, dtype=np.int)
+
     def remove_previous_action(self, agent_ids):
         for agent_id in agent_ids:
             if agent_id in self.previous_action_dict:
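
Note (not part of the patch): below is a minimal, self-contained sketch of the lookup pattern that the new retrieve_previous_action_single() uses, a single dict .get() with a zero-vector fallback, instead of building a (len(agent_ids), num_branches) matrix for one id and slicing out row 0 as retrieve_previous_action() does. The PreviousActionStore class is hypothetical, and dtype=int is used in place of the patch's np.int (deprecated in recent NumPy) so the snippet runs as-is.

import numpy as np

class PreviousActionStore:
    """Hypothetical stand-in for the policy's previous-action bookkeeping."""

    def __init__(self, num_branches: int):
        self.num_branches = num_branches
        self.previous_action_dict = {}

    def retrieve_previous_action_single(self, agent_id: str) -> np.ndarray:
        # Single dict lookup; fall back to a zero action vector for unseen agents.
        prev_action = self.previous_action_dict.get(agent_id)
        if prev_action is not None:
            return prev_action
        return np.zeros(self.num_branches, dtype=int)

store = PreviousActionStore(num_branches=2)
store.previous_action_dict["agent-0"] = np.array([1, 0])
print(store.retrieve_previous_action_single("agent-0"))  # [1 0]
print(store.retrieve_previous_action_single("agent-7"))  # [0 0] for an unknown agent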