 import logging
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union, Set
+from typing import Any, Dict, List, Optional, Tuple, Union

 import gym
 from gym import error, spaces
@@ -74,7 +74,9 @@ def __init__(

         self.visual_obs = None
         self._n_agents = -1
-        self._done_agents: Set[int] = set()
+
+        self.agent_mapper = AgentIdIndexMapper()
+
         # Save the step result from the last time all Agents requested decisions.
         self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
@@ -121,6 +123,7 @@ def __init__(
121
123
step_result = self ._env .get_step_result (self .brain_name )
122
124
self ._check_agents (step_result .n_agents ())
123
125
self ._previous_step_result = step_result
126
+ self .agent_mapper .set_initial_agents (list (self ._previous_step_result .agent_id ))
124
127
125
128
# Set observation and action spaces
126
129
if self .group_spec .is_action_discrete ():
@@ -368,52 +371,58 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
368
371
"The number of agents in the scene does not match the expected number."
369
372
)
370
373
371
- # remove the done Agents
372
- indices_to_keep : List [int ] = []
373
- for index , is_done in enumerate (step_result .done ):
374
- if not is_done :
375
- indices_to_keep .append (index )
374
+ if step_result .n_agents () - sum (step_result .done ) != self ._n_agents :
375
+ raise UnityGymException (
376
+ "The number of agents in the scene does not match the expected number."
377
+ )
378
+
379
+ for index , agent_id in enumerate (step_result .agent_id ):
380
+ if step_result .done [index ]:
381
+ self .agent_mapper .mark_agent_done (agent_id , step_result .reward [index ])
376
382
377
383
# Set the new AgentDone flags to True
378
384
# Note that the corresponding agent_id that gets marked done will be different
379
385
# than the original agent that was done, but this is OK since the gym interface
380
386
# only cares about the ordering.
381
387
for index , agent_id in enumerate (step_result .agent_id ):
382
388
if not self ._previous_step_result .contains_agent (agent_id ):
389
+ # Register this agent, and get the reward of the previous agent that
390
+ # was in its index, so that we can return it to the gym.
391
+ last_reward = self .agent_mapper .register_new_agent_id (agent_id )
383
392
step_result .done [index ] = True
384
- if agent_id in self ._done_agents :
385
- step_result .done [index ] = True
386
- self ._done_agents = set ()
393
+ step_result .reward [index ] = last_reward
394
+
387
395
self ._previous_step_result = step_result # store the new original
388
396
397
+ # Get a permutation of the agent IDs so that a given ID stays in the same
398
+ # index as where it was first seen.
399
+ new_id_order = self .agent_mapper .get_id_permutation (list (step_result .agent_id ))
400
+
389
401
_mask : Optional [List [np .array ]] = None
390
402
if step_result .action_mask is not None :
391
403
_mask = []
392
404
for mask_index in range (len (step_result .action_mask )):
393
- _mask .append (step_result .action_mask [mask_index ][indices_to_keep ])
405
+ _mask .append (step_result .action_mask [mask_index ][new_id_order ])
394
406
new_obs : List [np .array ] = []
395
407
for obs_index in range (len (step_result .obs )):
396
- new_obs .append (step_result .obs [obs_index ][indices_to_keep ])
408
+ new_obs .append (step_result .obs [obs_index ][new_id_order ])
397
409
return BatchedStepResult (
398
410
obs = new_obs ,
399
- reward = step_result .reward [indices_to_keep ],
400
- done = step_result .done [indices_to_keep ],
401
- max_step = step_result .max_step [indices_to_keep ],
402
- agent_id = step_result .agent_id [indices_to_keep ],
411
+ reward = step_result .reward [new_id_order ],
412
+ done = step_result .done [new_id_order ],
413
+ max_step = step_result .max_step [new_id_order ],
414
+ agent_id = step_result .agent_id [new_id_order ],
403
415
action_mask = _mask ,
404
416
)
405
417
406
418
def _sanitize_action (self , action : np .array ) -> np .array :
407
- if self ._previous_step_result .n_agents () == self ._n_agents :
408
- return action
409
419
sanitized_action = np .zeros (
410
420
(self ._previous_step_result .n_agents (), self .group_spec .action_size )
411
421
)
412
- input_index = 0
413
- for index in range (self ._previous_step_result .n_agents ()):
422
+ for index , agent_id in enumerate (self ._previous_step_result .agent_id ):
414
423
if not self ._previous_step_result .done [index ]:
415
- sanitized_action [ index , :] = action [ input_index , :]
416
- input_index = input_index + 1
424
+ array_index = self . agent_mapper . get_gym_index ( agent_id )
425
+ sanitized_action [ index , :] = action [ array_index , :]
417
426
return sanitized_action
418
427
419
428
def _step (self , needs_reset : bool = False ) -> BatchedStepResult :
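(Illustration only, not part of this change.) A minimal sketch of how the reindexing above is meant to behave, assuming the AgentIdIndexMapper class added at the end of this diff is in scope; the agent IDs and reward values are made up:

    import numpy as np

    mapper = AgentIdIndexMapper()
    mapper.set_initial_agents([10, 11, 12])  # gym indices: 10 -> 0, 11 -> 1, 12 -> 2

    mapper.mark_agent_done(11, 0.5)          # agent 11 terminates; gym index 1 is freed
    mapper.register_new_agent_id(13)         # agent 13 takes over index 1, returns 0.5

    # The next step result lists the agents in the order [10, 12, 13].
    new_id_order = mapper.get_id_permutation([10, 12, 13])  # -> [0, 2, 1]

    # Fancy indexing with new_id_order keeps each agent's data in its original gym slot:
    rewards = np.array([1.0, 2.0, 3.0])      # rewards for agents 10, 12, 13
    rewards[new_id_order]                    # -> [1.0, 3.0, 2.0], i.e. slots of agents 10, 13, 12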
@@ -432,7 +441,9 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
                     "The environment does not have the expected amount of agents."
                     + "Some agents did not request decisions at the same time."
                 )
-            self._done_agents.update(list(info.agent_id))
+            for agent_id, reward in zip(info.agent_id, info.reward):
+                self.agent_mapper.mark_agent_done(agent_id, reward)
+
             self._env.step()
             info = self._env.get_step_result(self.brain_name)
         return self._sanitize_info(info)
@@ -499,3 +510,91 @@ def lookup_action(self, action):
         :return: The List containing the branched actions.
         """
         return self.action_lookup[action]
+
+
+class AgentIdIndexMapper:
+    def __init__(self) -> None:
+        self._agent_id_to_gym_index: Dict[int, int] = {}
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        """
+        Provide the initial list of agent ids for the mapper.
+        """
+        for idx, agent_id in enumerate(agent_ids):
+            self._agent_id_to_gym_index[agent_id] = idx
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        """
+        Declare the agent done with the corresponding final reward.
+        """
+        gym_index = self._agent_id_to_gym_index.pop(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        """
+        Adds the new agent ID and returns the reward to use for the previous agent in this index.
+        """
+        # Any free index is OK here.
+        free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
+        self._agent_id_to_gym_index[agent_id] = free_index
+        return last_reward
+
+    def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+        """
+        Get the permutation from new agent ids to the order that preserves the positions of previous agents.
+        The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
+        """
+        # Map the new agent ids to their index
+        new_agent_ids_to_index = {
+            agent_id: idx for idx, agent_id in enumerate(agent_ids)
+        }
+
+        # Make the output list. We don't write to it sequentially, so start with dummy values.
+        new_permutation = [-1] * len(agent_ids)
+
+        # For each agent ID, find the new index of the agent, and write it in the original index.
+        for agent_id, original_index in self._agent_id_to_gym_index.items():
+            new_permutation[original_index] = new_agent_ids_to_index[agent_id]
+        return new_permutation
+
+    def get_gym_index(self, agent_id: int) -> int:
+        """
+        Get the gym index for the current agent.
+        """
+        return self._agent_id_to_gym_index[agent_id]
+
+
+class AgentIdIndexMapperSlow:
+    """
+    Reference implementation of AgentIdIndexMapper.
+    The operations are O(N^2) so it shouldn't be used for large numbers of agents.
+    See AgentIdIndexMapper for method descriptions.
+    """
+
+    def __init__(self) -> None:
+        self._gym_id_order: List[int] = []
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        self._gym_id_order = list(agent_ids)
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        gym_index = self._gym_id_order.index(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+        self._gym_id_order[gym_index] = -1
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        original_index = self._gym_id_order.index(-1)
+        self._gym_id_order[original_index] = agent_id
+        reward = self._done_agents_index_to_last_reward.pop(original_index)
+        return reward
+
+    def get_id_permutation(self, agent_ids):
+        new_id_order = []
+        for agent_id in self._gym_id_order:
+            new_id_order.append(agent_ids.index(agent_id))
+        return new_id_order
+
+    def get_gym_index(self, agent_id: int) -> int:
+        return self._gym_id_order.index(agent_id)
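(Illustration only, not part of this change.) The fast and slow mappers are expected to agree on every operation; a small sanity-check sketch, assuming both classes above are in scope and using arbitrary IDs and rewards:

    for mapper in (AgentIdIndexMapper(), AgentIdIndexMapperSlow()):
        mapper.set_initial_agents([1001, 1002, 1003])
        mapper.mark_agent_done(1002, 10.0)
        # The returned reward is the final reward of the agent that held the freed index.
        assert mapper.register_new_agent_id(1004) == 10.0
        # Agent 1004 now occupies the gym index that agent 1002 used to have.
        assert mapper.get_gym_index(1004) == 1
        assert mapper.get_id_permutation([1001, 1003, 1004]) == [0, 2, 1]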