
Commit 9e7eea5

Author: Ervin T

Replace BrainInfos with BatchedStepResult (#3207)

1 parent 5adea00 · commit 9e7eea5

36 files changed: +637 −923 lines

gym-unity/gym_unity/envs/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def step(self, action: List[Any]) -> GymStepResult:
             observation (object/list): agent's observation of the current environment
             reward (float/list) : amount of reward returned after previous action
             done (boolean/list): whether the episode has ended.
-            info (dict): contains auxiliary diagnostic information, including BrainInfo.
+            info (dict): contains auxiliary diagnostic information, including BatchedStepResult.
         """

         # Use random actions for all other agents in environment.
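
For context, a minimal usage sketch (not part of this commit) of the gym wrapper after the rename. The Unity build path is a placeholder, and the exact info-dict key under which the BatchedStepResult is exposed is not shown in this hunk, so it is left unassumed:

# Sketch only: stepping the gym wrapper and reading the info dict, which per this
# commit now carries a BatchedStepResult instead of a BrainInfo.
from gym_unity.envs import UnityEnv

env = UnityEnv("path/to/UnityBuild", use_visual=False, multiagent=False)  # placeholder build path
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
# 'info' is a plain dict of auxiliary diagnostics; inspect its keys rather than
# assuming a particular name for the BatchedStepResult entry.
print(type(reward), done, list(info.keys()))
env.close()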

gym-unity/gym_unity/tests/test_gym.py

Lines changed: 15 additions & 15 deletions
@@ -9,9 +9,9 @@

 @mock.patch("gym_unity.envs.UnityEnvironment")
 def test_gym_wrapper(mock_env):
-    mock_brain = create_mock_group_spec()
-    mock_braininfo = create_mock_vector_step_result()
-    setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
+    mock_spec = create_mock_group_spec()
+    mock_step = create_mock_vector_step_result()
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)

     env = UnityEnv(" ", use_visual=False, multiagent=False)
     assert isinstance(env, UnityEnv)
@@ -28,9 +28,9 @@ def test_gym_wrapper(mock_env):

 @mock.patch("gym_unity.envs.UnityEnvironment")
 def test_multi_agent(mock_env):
-    mock_brain = create_mock_group_spec()
-    mock_braininfo = create_mock_vector_step_result(num_agents=2)
-    setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
+    mock_spec = create_mock_group_spec()
+    mock_step = create_mock_vector_step_result(num_agents=2)
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)

     with pytest.raises(UnityGymException):
         UnityEnv(" ", multiagent=False)
@@ -47,11 +47,11 @@ def test_multi_agent(mock_env):

 @mock.patch("gym_unity.envs.UnityEnvironment")
 def test_branched_flatten(mock_env):
-    mock_brain = create_mock_group_spec(
+    mock_spec = create_mock_group_spec(
         vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
     )
-    mock_braininfo = create_mock_vector_step_result(num_agents=1)
-    setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
+    mock_step = create_mock_vector_step_result(num_agents=1)
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)

     env = UnityEnv(" ", use_visual=False, multiagent=False, flatten_branched=True)
     assert isinstance(env.action_space, spaces.Discrete)
@@ -67,9 +67,9 @@ def test_branched_flatten(mock_env):
 @pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
 @mock.patch("gym_unity.envs.UnityEnvironment")
 def test_gym_wrapper_visual(mock_env, use_uint8):
-    mock_brain = create_mock_group_spec(number_visual_observations=1)
-    mock_braininfo = create_mock_vector_step_result(number_visual_observations=1)
-    setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
+    mock_spec = create_mock_group_spec(number_visual_observations=1)
+    mock_step = create_mock_vector_step_result(number_visual_observations=1)
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)

     env = UnityEnv(" ", use_visual=True, multiagent=False, uint8_visual=use_uint8)
     assert isinstance(env, UnityEnv)
@@ -117,10 +117,10 @@ def create_mock_group_spec(

 def create_mock_vector_step_result(num_agents=1, number_visual_observations=0):
     """
-    Creates a mock BrainInfo with vector observations. Imitates constant
+    Creates a mock BatchedStepResult with vector observations. Imitates constant
     vector observations, rewards, dones, and agents.

-    :int num_agents: Number of "agents" to imitate in your BrainInfo values.
+    :int num_agents: Number of "agents" to imitate in your BatchedStepResult values.
     """
     obs = [np.array([num_agents * [1, 2, 3]])]
     if number_visual_observations:
@@ -134,7 +134,7 @@ def create_mock_vector_step_result(num_agents=1, number_visual_observations=0):
 def setup_mock_unityenvironment(mock_env, mock_spec, mock_result):
     """
     Takes a mock UnityEnvironment and adds the appropriate properties, defined by the mock
-    BrainParameters and BrainInfo.
+    GroupSpec and BatchedStepResult.

     :Mock mock_env: A mock UnityEnvironment, usually empty.
     :Mock mock_spec: An AgentGroupSpec object that specifies the params of this environment.
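
The renamed helpers can also be wired together by hand; a small sketch, under the assumptions that the test module is importable and that a bare MagicMock is an acceptable stand-in for the patched UnityEnvironment:

# Sketch only: reproducing the test fixtures outside of @mock.patch.
from unittest import mock

from gym_unity.tests.test_gym import (
    create_mock_group_spec,
    create_mock_vector_step_result,
    setup_mock_unityenvironment,
)

mock_env = mock.MagicMock()  # stand-in for the patched UnityEnvironment (assumption)
mock_spec = create_mock_group_spec(
    vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
)
mock_step = create_mock_vector_step_result(num_agents=2)
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)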

ml-agents-envs/mlagents_envs/base_env.py

Lines changed: 13 additions & 3 deletions
@@ -90,14 +90,24 @@ def __init__(self, obs, reward, done, max_step, agent_id, action_mask):
         self.max_step: np.ndarray = max_step
         self.agent_id: np.ndarray = agent_id
         self.action_mask: Optional[List[np.ndarray]] = action_mask
-        self._agent_id_to_index: Optional[Dict[int, int]] = None
+        self._agent_id_to_index: Optional[Dict[AgentId, int]] = None

-    def contains_agent(self, agent_id: AgentId) -> bool:
+    @property
+    def agent_id_to_index(self) -> Dict[AgentId, int]:
+        """
+        Returns the index of the agent_id in this BatchedStepResult, and
+        -1 if agent_id is not in this BatchedStepResult.
+        :param agent_id: The id of the agent
+        :returns: The index of the agent_id, and -1 if not found.
+        """
         if self._agent_id_to_index is None:
             self._agent_id_to_index = {}
             for a_idx, a_id in enumerate(self.agent_id):
                 self._agent_id_to_index[a_id] = a_idx
-        return agent_id in self._agent_id_to_index
+        return self._agent_id_to_index
+
+    def contains_agent(self, agent_id: AgentId) -> bool:
+        return agent_id in self.agent_id_to_index

     def get_agent_step_result(self, agent_id: AgentId) -> StepResult:
         """
ml-agents/mlagents/trainers/action_info.py

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 from typing import NamedTuple, Any, Dict, List
 import numpy as np
+from mlagents_envs.base_env import AgentId

 ActionInfoOutputs = Dict[str, np.ndarray]

@@ -8,4 +9,4 @@ class ActionInfo(NamedTuple):
     action: Any
     value: Any
     outputs: ActionInfoOutputs
-    agents: List[str]
+    agent_ids: List[AgentId]
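
A tiny sketch of the renamed field in use; the values are placeholders, since real instances come from the policy's get_action output:

# Sketch only: constructing an ActionInfo with the renamed agent_ids field.
import numpy as np
from mlagents.trainers.action_info import ActionInfo

action_info = ActionInfo(
    action=np.array([[0.1, -0.2]], dtype=np.float32),  # placeholder policy output
    value=np.array([0.5], dtype=np.float32),           # placeholder value estimate
    outputs={"action": np.array([[0.1, -0.2]], dtype=np.float32)},
    agent_ids=[7],  # local AgentId values; made global downstream by the AgentProcessor
)
print(action_info.agent_ids)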

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 61 additions & 54 deletions
@@ -1,13 +1,15 @@
 import sys
+import numpy as np
 from typing import List, Dict, Deque, TypeVar, Generic
 from collections import defaultdict, Counter, deque

+from mlagents_envs.base_env import BatchedStepResult
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
-from mlagents.trainers.brain import BrainInfo
 from mlagents.trainers.tf_policy import TFPolicy
 from mlagents.trainers.policy import Policy
 from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
 from mlagents.trainers.stats import StatsReporter
+from mlagents.trainers.env_manager import get_global_agent_id

 T = TypeVar("T")

@@ -35,7 +37,7 @@ def __init__(
         :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
         """
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_brain_info: Dict[str, BrainInfo] = {}
+        self.last_step_result: Dict[str, BatchedStepResult] = {}
         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
@@ -50,12 +52,15 @@ def __init__(
         self.behavior_id = behavior_id

     def add_experiences(
-        self, curr_info: BrainInfo, previous_action: ActionInfo
+        self,
+        batched_step_result: BatchedStepResult,
+        worker_id: int,
+        previous_action: ActionInfo,
     ) -> None:
         """
         Adds experiences to each agent's experience history.
-        :param curr_info: current BrainInfo.
-        :param previous_action: The return value of the Policy's get_action method.
+        :param batched_step_result: current BatchedStepResult.
+        :param previous_action: The outputs of the Policy's get_action method.
         """
         take_action_outputs = previous_action.outputs
         if take_action_outputs:
@@ -65,99 +70,101 @@ def add_experiences(
                 "Policy/Learning Rate", take_action_outputs["learning_rate"]
             )

-        for agent_id in previous_action.agents:
-            self.last_take_action_outputs[agent_id] = take_action_outputs
-
-        # Store the environment reward
-        tmp_environment_reward = curr_info.rewards
-
-        for agent_idx, agent_id in enumerate(curr_info.agents):
-            stored_info = self.last_brain_info.get(agent_id, None)
+        # Make unique agent_ids that are global across workers
+        action_global_agent_ids = [
+            get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
+        ]
+        for global_id in action_global_agent_ids:
+            self.last_take_action_outputs[global_id] = take_action_outputs
+
+        for _id in np.nditer(batched_step_result.agent_id):  # Explicit numpy iteration
+            local_id = int(
+                _id
+            )  # Needed for mypy to pass since ndarray has no content type
+            curr_agent_step = batched_step_result.get_agent_step_result(local_id)
+            global_id = get_global_agent_id(worker_id, local_id)
+            stored_step = self.last_step_result.get(global_id, None)
             stored_take_action_outputs = self.last_take_action_outputs.get(
-                agent_id, None
+                global_id, None
            )
-            if stored_info is not None and stored_take_action_outputs is not None:
-                prev_idx = stored_info.agents.index(agent_id)
-                obs = []
-                if not stored_info.local_done[prev_idx]:
-                    for i, _ in enumerate(stored_info.visual_observations):
-                        obs.append(stored_info.visual_observations[i][prev_idx])
-                    if self.policy.use_vec_obs:
-                        obs.append(stored_info.vector_observations[prev_idx])
+            if stored_step is not None and stored_take_action_outputs is not None:
+                # We know the step is from the same worker, so use the local agent id.
+                stored_agent_step = stored_step.get_agent_step_result(local_id)
+                idx = stored_step.agent_id_to_index[local_id]
+                obs = stored_agent_step.obs
+                if not stored_agent_step.done:
                     if self.policy.use_recurrent:
-                        memory = self.policy.retrieve_memories([agent_id])[0, :]
+                        memory = self.policy.retrieve_memories([global_id])[0, :]
                     else:
                         memory = None

-                    done = curr_info.local_done[agent_idx]
-                    max_step = curr_info.max_reached[agent_idx]
+                    done = curr_agent_step.done
+                    max_step = curr_agent_step.max_step

                     # Add the outputs of the last eval
-                    action = stored_take_action_outputs["action"][prev_idx]
+                    action = stored_take_action_outputs["action"][idx]
                     if self.policy.use_continuous_act:
-                        action_pre = stored_take_action_outputs["pre_action"][prev_idx]
+                        action_pre = stored_take_action_outputs["pre_action"][idx]
                     else:
                         action_pre = None
-                    action_probs = stored_take_action_outputs["log_probs"][prev_idx]
-                    action_masks = stored_info.action_masks[prev_idx]
-                    prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]
+                    action_probs = stored_take_action_outputs["log_probs"][idx]
+                    action_mask = stored_agent_step.action_mask
+                    prev_action = self.policy.retrieve_previous_action([global_id])[
+                        0, :
+                    ]

                     experience = AgentExperience(
                         obs=obs,
-                        reward=tmp_environment_reward[agent_idx],
+                        reward=curr_agent_step.reward,
                         done=done,
                         action=action,
                         action_probs=action_probs,
                         action_pre=action_pre,
-                        action_mask=action_masks,
+                        action_mask=action_mask,
                         prev_action=prev_action,
                         max_step=max_step,
                         memory=memory,
                     )
                     # Add the value outputs if needed
-                    self.experience_buffers[agent_id].append(experience)
-                    self.episode_rewards[agent_id] += tmp_environment_reward[agent_idx]
+                    self.experience_buffers[global_id].append(experience)
+                    self.episode_rewards[global_id] += curr_agent_step.reward
                 if (
-                    curr_info.local_done[agent_idx]
+                    curr_agent_step.done
                     or (
-                        len(self.experience_buffers[agent_id])
+                        len(self.experience_buffers[global_id])
                         >= self.max_trajectory_length
                     )
-                ) and len(self.experience_buffers[agent_id]) > 0:
+                ) and len(self.experience_buffers[global_id]) > 0:
                     # Make next AgentExperience
-                    next_obs = []
-                    for i, _ in enumerate(curr_info.visual_observations):
-                        next_obs.append(curr_info.visual_observations[i][agent_idx])
-                    if self.policy.use_vec_obs:
-                        next_obs.append(curr_info.vector_observations[agent_idx])
+                    next_obs = curr_agent_step.obs
                     trajectory = Trajectory(
-                        steps=self.experience_buffers[agent_id],
-                        agent_id=agent_id,
+                        steps=self.experience_buffers[global_id],
+                        agent_id=global_id,
                         next_obs=next_obs,
                         behavior_id=self.behavior_id,
                     )
                     for traj_queue in self.trajectory_queues:
                         traj_queue.put(trajectory)
-                    self.experience_buffers[agent_id] = []
-                    if curr_info.local_done[agent_idx]:
+                    self.experience_buffers[global_id] = []
+                    if curr_agent_step.done:
                         self.stats_reporter.add_stat(
                             "Environment/Cumulative Reward",
-                            self.episode_rewards.get(agent_id, 0),
+                            self.episode_rewards.get(global_id, 0),
                         )
                         self.stats_reporter.add_stat(
                             "Environment/Episode Length",
-                            self.episode_steps.get(agent_id, 0),
+                            self.episode_steps.get(global_id, 0),
                         )
-                        del self.episode_steps[agent_id]
-                        del self.episode_rewards[agent_id]
-                elif not curr_info.local_done[agent_idx]:
-                    self.episode_steps[agent_id] += 1
+                        del self.episode_steps[global_id]
+                        del self.episode_rewards[global_id]
+                elif not curr_agent_step.done:
+                    self.episode_steps[global_id] += 1

-            self.last_brain_info[agent_id] = curr_info
+            self.last_step_result[global_id] = batched_step_result

         if "action" in take_action_outputs:
             self.policy.save_previous_action(
-                previous_action.agents, take_action_outputs["action"]
+                previous_action.agent_ids, take_action_outputs["action"]
             )

     def publish_trajectory_queue(
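
The diff imports get_global_agent_id from mlagents.trainers.env_manager but does not show its body. A plausible sketch, under the assumption that it simply joins the worker id and the local agent id into one unique string key:

# Hypothetical sketch; the real helper lives in mlagents.trainers.env_manager
# and its exact formatting may differ.
def get_global_agent_id(worker_id: int, agent_id: int) -> str:
    """Combine a worker id and a per-worker agent id into a globally unique key."""
    return f"agent_{worker_id}-{agent_id}"

# Two workers reporting the same local agent id no longer collide:
print(get_global_agent_id(0, 3))  # agent_0-3
print(get_global_agent_id(1, 3))  # agent_1-3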
