import sys
+import numpy as np
from typing import List, Dict, Deque, TypeVar, Generic
from collections import defaultdict, Counter, deque

+from mlagents_envs.base_env import BatchedStepResult
from mlagents.trainers.trajectory import Trajectory, AgentExperience
-from mlagents.trainers.brain import BrainInfo
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.stats import StatsReporter
+from mlagents.trainers.env_manager import get_global_agent_id


T = TypeVar("T")
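
# The helper imported above, get_global_agent_id, is what lets a single AgentProcessor
# track agents arriving from several environment workers without id collisions: every
# worker-local agent id gets namespaced by its worker id. A minimal sketch of such a
# mapping, assuming it simply joins the two ids into one string key (the exact format
# used by mlagents.trainers.env_manager may differ):
def _global_agent_id_sketch(worker_id: int, local_agent_id: int) -> str:
    # e.g. worker 2, local agent 7 -> "2-7", usable as a key into the
    # last_step_result / experience_buffers dictionaries below.
    return f"{worker_id}-{local_agent_id}"
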
@@ -35,7 +37,7 @@ def __init__(
        :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
        """
        self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_brain_info: Dict[str, BrainInfo] = {}
+        self.last_step_result: Dict[str, BatchedStepResult] = {}
        # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
        # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
        self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
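
# The comment above captures the off-by-one bookkeeping this processor relies on:
# last_take_action_outputs holds a_t, the action chosen for the previously seen
# observation s_t, while the step result being processed carries s_(t+1) and the
# reward that a_t produced. A stripped-down, hypothetical sketch of that pairing
# (plain variables instead of the real class fields):
def _pair_transition_sketch(prev_obs, prev_action, new_reward, new_obs):
    # The experience handed to the trainer is (s_t, a_t, r_(t+1)); s_(t+1) is
    # returned so it can play the role of s_t on the next call.
    experience = (prev_obs, prev_action, new_reward)
    return experience, new_obs
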
@@ -50,12 +52,15 @@ def __init__(
        self.behavior_id = behavior_id

    def add_experiences(
-        self, curr_info: BrainInfo, previous_action: ActionInfo
+        self,
+        batched_step_result: BatchedStepResult,
+        worker_id: int,
+        previous_action: ActionInfo,
    ) -> None:
        """
        Adds experiences to each agent's experience history.
-        :param curr_info: current BrainInfo.
-        :param previous_action: The return value of the Policy's get_action method.
+        :param batched_step_result: current BatchedStepResult.
+        :param previous_action: The outputs of the Policy's get_action method.
        """
        take_action_outputs = previous_action.outputs
        if take_action_outputs:
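
# With the new signature, the caller forwards both the BatchedStepResult and the id of
# the worker that produced it. A hedged usage sketch; processor, step_results and
# last_actions are hypothetical stand-ins, not names taken from this change:
for worker_id, step_result in enumerate(step_results):
    # last_actions[worker_id] is the ActionInfo returned by the policy's get_action
    # for that worker's previous step.
    processor.add_experiences(step_result, worker_id, last_actions[worker_id])
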
@@ -65,99 +70,101 @@ def add_experiences(
                "Policy/Learning Rate", take_action_outputs["learning_rate"]
            )

-        for agent_id in previous_action.agents:
-            self.last_take_action_outputs[agent_id] = take_action_outputs
-
-        # Store the environment reward
-        tmp_environment_reward = curr_info.rewards
-
-        for agent_idx, agent_id in enumerate(curr_info.agents):
-            stored_info = self.last_brain_info.get(agent_id, None)
+        # Make unique agent_ids that are global across workers
+        action_global_agent_ids = [
+            get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
+        ]
+        for global_id in action_global_agent_ids:
+            self.last_take_action_outputs[global_id] = take_action_outputs
+
+        for _id in np.nditer(batched_step_result.agent_id):  # Explicit numpy iteration
+            local_id = int(
+                _id
+            )  # Needed for mypy to pass since ndarray has no content type
+            curr_agent_step = batched_step_result.get_agent_step_result(local_id)
+            global_id = get_global_agent_id(worker_id, local_id)
+            stored_step = self.last_step_result.get(global_id, None)
            stored_take_action_outputs = self.last_take_action_outputs.get(
-                agent_id, None
+                global_id, None
            )
-            if stored_info is not None and stored_take_action_outputs is not None:
-                prev_idx = stored_info.agents.index(agent_id)
-                obs = []
-                if not stored_info.local_done[prev_idx]:
-                    for i, _ in enumerate(stored_info.visual_observations):
-                        obs.append(stored_info.visual_observations[i][prev_idx])
-                    if self.policy.use_vec_obs:
-                        obs.append(stored_info.vector_observations[prev_idx])
+            if stored_step is not None and stored_take_action_outputs is not None:
+                # We know the step is from the same worker, so use the local agent id.
+                stored_agent_step = stored_step.get_agent_step_result(local_id)
+                idx = stored_step.agent_id_to_index[local_id]
+                obs = stored_agent_step.obs
+                if not stored_agent_step.done:
                    if self.policy.use_recurrent:
-                        memory = self.policy.retrieve_memories([agent_id])[0, :]
+                        memory = self.policy.retrieve_memories([global_id])[0, :]
                    else:
                        memory = None

-                    done = curr_info.local_done[agent_idx]
-                    max_step = curr_info.max_reached[agent_idx]
+                    done = curr_agent_step.done
+                    max_step = curr_agent_step.max_step

                    # Add the outputs of the last eval
-                    action = stored_take_action_outputs["action"][prev_idx]
+                    action = stored_take_action_outputs["action"][idx]
                    if self.policy.use_continuous_act:
-                        action_pre = stored_take_action_outputs["pre_action"][prev_idx]
+                        action_pre = stored_take_action_outputs["pre_action"][idx]
                    else:
                        action_pre = None
-                    action_probs = stored_take_action_outputs["log_probs"][prev_idx]
-                    action_masks = stored_info.action_masks[prev_idx]
-                    prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]
+                    action_probs = stored_take_action_outputs["log_probs"][idx]
+                    action_mask = stored_agent_step.action_mask
+                    prev_action = self.policy.retrieve_previous_action([global_id])[
+                        0, :
+                    ]

                    experience = AgentExperience(
                        obs=obs,
-                        reward=tmp_environment_reward[agent_idx],
+                        reward=curr_agent_step.reward,
                        done=done,
                        action=action,
                        action_probs=action_probs,
                        action_pre=action_pre,
-                        action_mask=action_masks,
+                        action_mask=action_mask,
                        prev_action=prev_action,
                        max_step=max_step,
                        memory=memory,
                    )
                    # Add the value outputs if needed
-                    self.experience_buffers[agent_id].append(experience)
-                    self.episode_rewards[agent_id] += tmp_environment_reward[agent_idx]
+                    self.experience_buffers[global_id].append(experience)
+                    self.episode_rewards[global_id] += curr_agent_step.reward
                if (
-                    curr_info.local_done[agent_idx]
+                    curr_agent_step.done
                    or (
-                        len(self.experience_buffers[agent_id])
+                        len(self.experience_buffers[global_id])
                        >= self.max_trajectory_length
                    )
-                ) and len(self.experience_buffers[agent_id]) > 0:
+                ) and len(self.experience_buffers[global_id]) > 0:
                    # Make next AgentExperience
-                    next_obs = []
-                    for i, _ in enumerate(curr_info.visual_observations):
-                        next_obs.append(curr_info.visual_observations[i][agent_idx])
-                    if self.policy.use_vec_obs:
-                        next_obs.append(curr_info.vector_observations[agent_idx])
+                    next_obs = curr_agent_step.obs
                    trajectory = Trajectory(
-                        steps=self.experience_buffers[agent_id],
-                        agent_id=agent_id,
+                        steps=self.experience_buffers[global_id],
+                        agent_id=global_id,
                        next_obs=next_obs,
                        behavior_id=self.behavior_id,
                    )
                    for traj_queue in self.trajectory_queues:
                        traj_queue.put(trajectory)
-                    self.experience_buffers[agent_id] = []
-                    if curr_info.local_done[agent_idx]:
+                    self.experience_buffers[global_id] = []
+                    if curr_agent_step.done:
                        self.stats_reporter.add_stat(
                            "Environment/Cumulative Reward",
-                            self.episode_rewards.get(agent_id, 0),
+                            self.episode_rewards.get(global_id, 0),
                        )
                        self.stats_reporter.add_stat(
                            "Environment/Episode Length",
-                            self.episode_steps.get(agent_id, 0),
+                            self.episode_steps.get(global_id, 0),
                        )
-                        del self.episode_steps[agent_id]
-                        del self.episode_rewards[agent_id]
-                elif not curr_info.local_done[agent_idx]:
-                    self.episode_steps[agent_id] += 1
+                        del self.episode_steps[global_id]
+                        del self.episode_rewards[global_id]
+                elif not curr_agent_step.done:
+                    self.episode_steps[global_id] += 1

-            self.last_brain_info[agent_id] = curr_info
+            self.last_step_result[global_id] = batched_step_result

        if "action" in take_action_outputs:
            self.policy.save_previous_action(
-                previous_action.agents, take_action_outputs["action"]
+                previous_action.agent_ids, take_action_outputs["action"]
            )

    def publish_trajectory_queue(