
Commit ea54e4a

[Bug fix] Gym last reward before Done (#3471)
* Fixing #3460
* Addressing comments
* Added 2 tests
* encapsulate the agent mapping operations (#3481)
* encapsulate the agent mapping operations
* rename, linear time impl
* cleanup
* dict.popitem
* update comments
* Update gym-unity/gym_unity/tests/test_gym.py

Co-authored-by: Chris Elion <[email protected]>
1 parent ef2b170 commit ea54e4a

File tree

2 files changed: +198 additions, −24 deletions

gym-unity/gym_unity/envs/__init__.py
gym-unity/gym_unity/tests/test_gym.py

gym-unity/gym_unity/envs/__init__.py

Lines changed: 122 additions & 23 deletions
@@ -1,7 +1,7 @@
 import logging
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union, Set
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import gym
 from gym import error, spaces
@@ -74,7 +74,9 @@ def __init__(
 
         self.visual_obs = None
         self._n_agents = -1
-        self._done_agents: Set[int] = set()
+
+        self.agent_mapper = AgentIdIndexMapper()
+
         # Save the step result from the last time all Agents requested decisions.
         self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
@@ -121,6 +123,7 @@ def __init__(
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
         self._previous_step_result = step_result
+        self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))
 
         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -368,52 +371,58 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
                 "The number of agents in the scene does not match the expected number."
             )
 
-        # remove the done Agents
-        indices_to_keep: List[int] = []
-        for index, is_done in enumerate(step_result.done):
-            if not is_done:
-                indices_to_keep.append(index)
+        if step_result.n_agents() - sum(step_result.done) != self._n_agents:
+            raise UnityGymException(
+                "The number of agents in the scene does not match the expected number."
+            )
+
+        for index, agent_id in enumerate(step_result.agent_id):
+            if step_result.done[index]:
+                self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])
 
         # Set the new AgentDone flags to True
         # Note that the corresponding agent_id that gets marked done will be different
         # than the original agent that was done, but this is OK since the gym interface
        # only cares about the ordering.
         for index, agent_id in enumerate(step_result.agent_id):
             if not self._previous_step_result.contains_agent(agent_id):
+                # Register this agent, and get the reward of the previous agent that
+                # was in its index, so that we can return it to the gym.
+                last_reward = self.agent_mapper.register_new_agent_id(agent_id)
                 step_result.done[index] = True
-            if agent_id in self._done_agents:
-                step_result.done[index] = True
-        self._done_agents = set()
+                step_result.reward[index] = last_reward
+
         self._previous_step_result = step_result  # store the new original
 
+        # Get a permutation of the agent IDs so that a given ID stays in the same
+        # index as where it was first seen.
+        new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))
+
         _mask: Optional[List[np.array]] = None
         if step_result.action_mask is not None:
             _mask = []
             for mask_index in range(len(step_result.action_mask)):
-                _mask.append(step_result.action_mask[mask_index][indices_to_keep])
+                _mask.append(step_result.action_mask[mask_index][new_id_order])
         new_obs: List[np.array] = []
         for obs_index in range(len(step_result.obs)):
-            new_obs.append(step_result.obs[obs_index][indices_to_keep])
+            new_obs.append(step_result.obs[obs_index][new_id_order])
         return BatchedStepResult(
             obs=new_obs,
-            reward=step_result.reward[indices_to_keep],
-            done=step_result.done[indices_to_keep],
-            max_step=step_result.max_step[indices_to_keep],
-            agent_id=step_result.agent_id[indices_to_keep],
+            reward=step_result.reward[new_id_order],
+            done=step_result.done[new_id_order],
+            max_step=step_result.max_step[new_id_order],
+            agent_id=step_result.agent_id[new_id_order],
             action_mask=_mask,
         )
 
     def _sanitize_action(self, action: np.array) -> np.array:
-        if self._previous_step_result.n_agents() == self._n_agents:
-            return action
         sanitized_action = np.zeros(
             (self._previous_step_result.n_agents(), self.group_spec.action_size)
         )
-        input_index = 0
-        for index in range(self._previous_step_result.n_agents()):
+        for index, agent_id in enumerate(self._previous_step_result.agent_id):
             if not self._previous_step_result.done[index]:
-                sanitized_action[index, :] = action[input_index, :]
-                input_index = input_index + 1
+                array_index = self.agent_mapper.get_gym_index(agent_id)
+                sanitized_action[index, :] = action[array_index, :]
         return sanitized_action
 
     def _step(self, needs_reset: bool = False) -> BatchedStepResult:
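Note on the hunk above: the reindexing in _sanitize_info relies on numpy fancy indexing — output slot i receives the entry at position new_id_order[i] of each incoming array, so an agent keeps the gym slot where it was first seen. A minimal standalone sketch of that idea (the array values and the permutation below are made up for illustration and are not part of this commit):

import numpy as np

# Agents arrive in a shuffled order this step...
incoming_agent_id = np.array([4, 2, 3, 1, 0])
incoming_reward = np.array([0.4, 0.2, 0.3, 0.1, 0.0])
# ...and the mapper reports, for each gym slot, where that slot's agent now sits.
new_id_order = [4, 3, 1, 2, 0]

# Fancy indexing puts every agent back into its original slot.
assert list(incoming_agent_id[new_id_order]) == [0, 1, 2, 3, 4]
assert list(incoming_reward[new_id_order]) == [0.0, 0.1, 0.2, 0.3, 0.4]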
@@ -432,7 +441,9 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
                    "The environment does not have the expected amount of agents."
                    + "Some agents did not request decisions at the same time."
                )
-            self._done_agents.update(list(info.agent_id))
+            for agent_id, reward in zip(info.agent_id, info.reward):
+                self.agent_mapper.mark_agent_done(agent_id, reward)
+
             self._env.step()
             info = self._env.get_step_result(self.brain_name)
         return self._sanitize_info(info)
@@ -499,3 +510,91 @@ def lookup_action(self, action):
         :return: The List containing the branched actions.
         """
         return self.action_lookup[action]
+
+
+class AgentIdIndexMapper:
+    def __init__(self) -> None:
+        self._agent_id_to_gym_index: Dict[int, int] = {}
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        """
+        Provide the initial list of agent ids for the mapper.
+        """
+        for idx, agent_id in enumerate(agent_ids):
+            self._agent_id_to_gym_index[agent_id] = idx
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        """
+        Declare the agent done with the corresponding final reward.
+        """
+        gym_index = self._agent_id_to_gym_index.pop(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        """
+        Adds the new agent ID and returns the reward to use for the previous agent in this index.
+        """
+        # Any free index is OK here.
+        free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
+        self._agent_id_to_gym_index[agent_id] = free_index
+        return last_reward
+
+    def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+        """
+        Get the permutation from new agent ids to the order that preserves the positions of previous agents.
+        The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
+        """
+        # Map the new agent ids to their index
+        new_agent_ids_to_index = {
+            agent_id: idx for idx, agent_id in enumerate(agent_ids)
+        }
+
+        # Make the output list. We don't write to it sequentially, so start with dummy values.
+        new_permutation = [-1] * len(agent_ids)
+
+        # For each agent ID, find the new index of the agent, and write it in the original index.
+        for agent_id, original_index in self._agent_id_to_gym_index.items():
+            new_permutation[original_index] = new_agent_ids_to_index[agent_id]
+        return new_permutation
+
+    def get_gym_index(self, agent_id: int) -> int:
+        """
+        Get the gym index for the current agent.
+        """
+        return self._agent_id_to_gym_index[agent_id]
+
+
+class AgentIdIndexMapperSlow:
+    """
+    Reference implementation of AgentIdIndexMapper.
+    The operations are O(N^2) so it shouldn't be used for large numbers of agents.
+    See AgentIdIndexMapper for method descriptions.
+    """
+
+    def __init__(self) -> None:
+        self._gym_id_order: List[int] = []
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        self._gym_id_order = list(agent_ids)
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        gym_index = self._gym_id_order.index(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+        self._gym_id_order[gym_index] = -1
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        original_index = self._gym_id_order.index(-1)
+        self._gym_id_order[original_index] = agent_id
+        reward = self._done_agents_index_to_last_reward.pop(original_index)
+        return reward
+
+    def get_id_permutation(self, agent_ids):
+        new_id_order = []
+        for agent_id in self._gym_id_order:
+            new_id_order.append(agent_ids.index(agent_id))
+        return new_id_order
+
+    def get_gym_index(self, agent_id: int) -> int:
+        return self._gym_id_order.index(agent_id)
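A minimal usage sketch of the AgentIdIndexMapper added above (the agent ids and rewards are made up for illustration; the import works because this commit defines the class in gym_unity/envs/__init__.py):

from gym_unity.envs import AgentIdIndexMapper

mapper = AgentIdIndexMapper()
mapper.set_initial_agents([10, 11, 12])  # gym slots 0, 1, 2

# Agent 11 finishes its episode with a final reward of 3.0;
# its gym slot (1) is freed and the reward is parked on that slot.
mapper.mark_agent_done(11, 3.0)

# A new agent (id 20) appears. It takes over a freed slot, and the parked
# reward is handed back so _sanitize_info can report it alongside done=True.
last_reward = mapper.register_new_agent_id(20)
assert last_reward == 3.0
assert mapper.get_gym_index(20) == 1  # reuses agent 11's slot
assert mapper.get_gym_index(10) == 0  # untouched agents keep their slots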

gym-unity/gym_unity/tests/test_gym.py

Lines changed: 76 additions & 1 deletion
@@ -3,7 +3,12 @@
 import numpy as np
 
 from gym import spaces
-from gym_unity.envs import UnityEnv, UnityGymException
+from gym_unity.envs import (
+    UnityEnv,
+    UnityGymException,
+    AgentIdIndexMapper,
+    AgentIdIndexMapperSlow,
+)
 from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult
 
 
@@ -84,6 +89,46 @@ def test_gym_wrapper_visual(mock_env, use_uint8):
     assert isinstance(info, dict)
 
 
+@mock.patch("gym_unity.envs.UnityEnvironment")
+def test_sanitize_action_shuffled_id(mock_env):
+    mock_spec = create_mock_group_spec(
+        vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
+    )
+    mock_step = create_mock_vector_step_result(num_agents=5)
+    mock_step.agent_id = np.array(range(5))
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
+    env = UnityEnv(" ", use_visual=False, multiagent=True)
+
+    shuffled_step_result = create_mock_vector_step_result(num_agents=5)
+    shuffled_order = [4, 2, 3, 1, 0]
+    shuffled_step_result.reward = np.array(shuffled_order)
+    shuffled_step_result.agent_id = np.array(shuffled_order)
+    sanitized_result = env._sanitize_info(shuffled_step_result)
+    for expected_reward, reward in zip(range(5), sanitized_result.reward):
+        assert expected_reward == reward
+    for expected_agent_id, agent_id in zip(range(5), sanitized_result.agent_id):
+        assert expected_agent_id == agent_id
+
+
+@mock.patch("gym_unity.envs.UnityEnvironment")
+def test_sanitize_action_one_agent_done(mock_env):
+    mock_spec = create_mock_group_spec(
+        vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
+    )
+    mock_step = create_mock_vector_step_result(num_agents=5)
+    mock_step.agent_id = np.array(range(5))
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
+    env = UnityEnv(" ", use_visual=False, multiagent=True)
+
+    received_step_result = create_mock_vector_step_result(num_agents=6)
+    received_step_result.agent_id = np.array(range(6))
+    # agent #3 (id = 2) is Done
+    received_step_result.done = np.array([False] * 2 + [True] + [False] * 3)
+    sanitized_result = env._sanitize_info(received_step_result)
+    for expected_agent_id, agent_id in zip([0, 1, 5, 3, 4], sanitized_result.agent_id):
+        assert expected_agent_id == agent_id
+
+
 # Helper methods
 
 
@@ -143,3 +188,33 @@ def setup_mock_unityenvironment(mock_env, mock_spec, mock_result):
     mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
     mock_env.return_value.get_agent_group_spec.return_value = mock_spec
     mock_env.return_value.get_step_result.return_value = mock_result
+
+
+@pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow])
+def test_agent_id_index_mapper(mapper_cls):
+    mapper = mapper_cls()
+    initial_agent_ids = [1001, 1002, 1003, 1004]
+    mapper.set_initial_agents(initial_agent_ids)
+
+    # Mark some agents as done with their last rewards.
+    mapper.mark_agent_done(1001, 42.0)
+    mapper.mark_agent_done(1004, 1337.0)
+
+    # Now add new agents, and get the rewards of the agents they replaced.
+    old_reward1 = mapper.register_new_agent_id(2001)
+    old_reward2 = mapper.register_new_agent_id(2002)
+
+    # The order of the rewards doesn't matter.
+    assert {old_reward1, old_reward2} == {42.0, 1337.0}
+
+    new_agent_ids = [1002, 1003, 2001, 2002]
+    permutation = mapper.get_id_permutation(new_agent_ids)
+    # Make sure it's actually a permutation - it needs to contain 0..N-1 with no repeats.
+    assert set(permutation) == set(range(0, 4))
+
+    # Agents that were in the initial group need to stay in the same slot.
+    # Agents that were added later can appear in any free slot.
+    permuted_ids = [new_agent_ids[i] for i in permutation]
+    for idx, agent_id in enumerate(initial_agent_ids):
+        if agent_id in permuted_ids:
+            assert permuted_ids[idx] == agent_id
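Since AgentIdIndexMapperSlow is the reference implementation, the two mappers above should agree on every operation; a quick equivalence sketch outside the test suite (agent ids chosen arbitrarily):

from gym_unity.envs import AgentIdIndexMapper, AgentIdIndexMapperSlow

for mapper_cls in (AgentIdIndexMapper, AgentIdIndexMapperSlow):
    mapper = mapper_cls()
    mapper.set_initial_agents([7, 8, 9])              # slots 0, 1, 2
    mapper.mark_agent_done(8, 0.5)                    # slot 1 is freed
    assert mapper.register_new_agent_id(42) == 0.5    # agent 42 inherits slot 1
    # With the new step order [9, 42, 7]: slot 0 (agent 7) is at index 2,
    # slot 1 (agent 42) is at index 1, and slot 2 (agent 9) is at index 0.
    assert mapper.get_id_permutation([9, 42, 7]) == [2, 1, 0]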
