Skip to content

Commit 38c1007

Browse files
authored
Fix export
2 parents 6f46b30 + 1a15577 commit 38c1007

File tree

4 files changed

+34
-42
lines changed

4 files changed

+34
-42
lines changed

ml-agents/mlagents/trainers/tests/torch/test_networks.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,19 +150,16 @@ def test_simple_actor(action_type):
150150
assert act.shape == (1, 1)
151151

152152
# Test forward
153-
actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
153+
actions, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
154154
[sample_obs], [], masks=masks
155155
)
156156
for act in actions:
157+
# This is different from above for ONNX export
157158
if action_type == ActionType.CONTINUOUS:
158-
assert act.shape == (
159-
act_size[0],
160-
1,
161-
) # This is different from above for ONNX export
159+
assert act.shape == (act_size[0], 1)
162160
else:
163-
assert act.shape == (1, 1)
161+
assert act.shape == tuple(act_size)
164162

165-
# TODO: Once export works properly. fix the shapes here.
166163
assert mem_size == 0
167164
assert is_cont == int(action_type == ActionType.CONTINUOUS)
168165
assert act_size_vec == torch.tensor(act_size)

ml-agents/mlagents/trainers/torch/layers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@ def memory_size(self) -> int:
142142
def forward(
143143
self, input_tensor: torch.Tensor, memories: torch.Tensor
144144
) -> Tuple[torch.Tensor, torch.Tensor]:
145-
h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
145+
# We don't use torch.split here since it is not supported by Barracuda
146+
h0 = memories[:, :, : self.hidden_size]
147+
c0 = memories[:, :, self.hidden_size :]
146148
hidden = (h0, c0)
147149
lstm_out, hidden_out = self.lstm(input_tensor, hidden)
148150
output_mem = torch.cat(hidden_out, dim=-1)

ml-agents/mlagents/trainers/torch/model_serialization.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,45 @@
1010

1111
class ModelSerializer:
1212
def __init__(self, policy):
13+
# ONNX only support input in NCHW (channel first) format.
14+
# Barracuda also expect to get data in NCHW.
15+
# Any multi-dimentional input should follow that otherwise will
16+
# cause problem to barracuda import.
1317
self.policy = policy
1418
batch_dim = [1]
19+
seq_len_dim = [1]
1520
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
21+
# create input shape of NCHW
22+
# (It's NHWC in self.policy.behavior_spec.observation_shapes)
1623
dummy_vis_obs = [
17-
torch.zeros(batch_dim + list(shape))
24+
torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
1825
for shape in self.policy.behavior_spec.observation_shapes
1926
if len(shape) == 3
2027
]
2128
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
22-
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size])
29+
dummy_memories = torch.zeros(
30+
batch_dim + seq_len_dim + [self.policy.export_memory_size]
31+
)
2332

24-
# Need to pass all possible inputs since currently keyword arguments is not
25-
# supported by torch.nn.export()
2633
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
2734

28-
# Input names can only contain actual input used since in torch.nn.export
29-
# it maps input_names only to input nodes that exist in the graph
30-
self.input_names = []
31-
self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}}
32-
if self.policy.use_vec_obs:
33-
self.input_names.append("vector_observation")
34-
self.dynamic_axes.update({"vector_observation": {0: "batch"}})
35-
for i in range(self.policy.vis_obs_size):
36-
self.input_names.append(f"visual_observation_{i}")
37-
self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}})
38-
if not self.policy.use_continuous_act:
39-
self.input_names.append("action_masks")
40-
self.dynamic_axes.update({"action_masks": {0: "batch"}})
41-
if self.policy.use_recurrent:
42-
self.input_names.append("memories")
43-
self.dynamic_axes.update({"memories": {0: "batch"}})
35+
self.input_names = (
36+
["vector_observation"]
37+
+ [f"visual_observation_{i}" for i in range(self.policy.vis_obs_size)]
38+
+ ["action_masks", "memories"]
39+
)
4440

4541
self.output_names = [
4642
"action",
47-
"action_probs",
4843
"version_number",
4944
"memory_size",
5045
"is_continuous_control",
5146
"action_output_shape",
5247
]
5348

49+
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
50+
self.dynamic_axes.update({"action": {0: "batch"}})
51+
5452
def export_policy_model(self, output_filepath: str) -> None:
5553
"""
5654
Exports a Torch model for a Policy to .onnx format for Unity embedding.
@@ -68,7 +66,6 @@ def export_policy_model(self, output_filepath: str) -> None:
6866
self.policy.actor_critic,
6967
self.dummy_input,
7068
onnx_output_path,
71-
verbose=False,
7269
opset_version=SerializationSettings.onnx_opset,
7370
input_names=self.input_names,
7471
output_names=self.output_names,

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def forward(
8686

8787
for idx, encoder in enumerate(self.visual_encoders):
8888
vis_input = vis_inputs[idx]
89-
vis_input = vis_input.permute([0, 3, 1, 2])
89+
if not torch.onnx.is_in_onnx_export():
90+
vis_input = vis_input.permute([0, 3, 1, 2])
9091
hidden = encoder(vis_input)
9192
encodes.append(hidden)
9293

@@ -192,8 +193,7 @@ def forward(
192193
vis_inputs: List[torch.Tensor],
193194
masks: Optional[torch.Tensor] = None,
194195
memories: Optional[torch.Tensor] = None,
195-
sequence_length: int = 1,
196-
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
196+
) -> Tuple[torch.Tensor, int, int, int, int]:
197197
"""
198198
Forward pass of the Actor for inference. This is required for export to ONNX, and
199199
the inputs and outputs of this method should not be changed without a respective change
@@ -325,23 +325,19 @@ def forward(
325325
vis_inputs: List[torch.Tensor],
326326
masks: Optional[torch.Tensor] = None,
327327
memories: Optional[torch.Tensor] = None,
328-
sequence_length: int = 1,
329-
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
328+
) -> Tuple[torch.Tensor, int, int, int, int]:
330329
"""
331330
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
332331
"""
333-
dists, _ = self.get_dists(
334-
vec_inputs, vis_inputs, masks, memories, sequence_length
335-
)
332+
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
336333
action_list = self.sample_action(dists)
337334
sampled_actions = torch.stack(action_list, dim=-1)
338335
if self.act_type == ActionType.CONTINUOUS:
339-
log_probs = dists[0].log_prob(sampled_actions)
336+
action_out = sampled_actions
340337
else:
341-
log_probs = dists[0].all_log_prob()
338+
action_out = dists[0].all_log_prob()
342339
return (
343-
sampled_actions,
344-
log_probs,
340+
action_out,
345341
self.version_number,
346342
torch.Tensor([self.network_body.memory_size]),
347343
self.is_continuous_int,

0 commit comments

Comments
 (0)