diff --git a/ml-agents/mlagents/trainers/ghost/trainer.py b/ml-agents/mlagents/trainers/ghost/trainer.py index 735926f2e5..ecb93a838f 100644 --- a/ml-agents/mlagents/trainers/ghost/trainer.py +++ b/ml-agents/mlagents/trainers/ghost/trainer.py @@ -1,7 +1,8 @@ # # Unity ML-Agents Toolkit # ## ML-Agent Learning (Ghost Trainer) -from typing import Deque, Dict, List, cast +from collections import defaultdict +from typing import Deque, Dict, DefaultDict, List, cast import numpy as np @@ -68,9 +69,9 @@ def __init__( self._internal_trajectory_queues: Dict[str, AgentManagerQueue[Trajectory]] = {} self._internal_policy_queues: Dict[str, AgentManagerQueue[Policy]] = {} - self._team_to_name_to_policy_queue: Dict[ + self._team_to_name_to_policy_queue: DefaultDict[ int, Dict[str, AgentManagerQueue[Policy]] - ] = {} + ] = defaultdict(dict) self._name_to_parsed_behavior_id: Dict[str, BehaviorIdentifiers] = {} @@ -413,14 +414,9 @@ def publish_policy_queue(self, policy_queue: AgentManagerQueue[Policy]) -> None: """ super().publish_policy_queue(policy_queue) parsed_behavior_id = self._name_to_parsed_behavior_id[policy_queue.behavior_id] - try: - self._team_to_name_to_policy_queue[parsed_behavior_id.team_id][ - parsed_behavior_id.brain_name - ] = policy_queue - except KeyError: - self._team_to_name_to_policy_queue[parsed_behavior_id.team_id] = { - parsed_behavior_id.brain_name: policy_queue - } + self._team_to_name_to_policy_queue[parsed_behavior_id.team_id][ + parsed_behavior_id.brain_name + ] = policy_queue if parsed_behavior_id.team_id == self.wrapped_trainer_team: # With a future multiagent trainer, this will be indexed by 'role' internal_policy_queue: AgentManagerQueue[Policy] = AgentManagerQueue( diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index fc7e7390ad..201a761f88 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -124,9 +124,9 @@ def _create_trainer_and_manager( parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id) brain_name = parsed_behavior_id.brain_name trainerthread = None - try: + if brain_name in self.trainers: trainer = self.trainers[brain_name] - except KeyError: + else: trainer = self.trainer_factory.generate(brain_name) self.trainers[brain_name] = trainer if trainer.threaded: