Skip to content

Commit 679fc90

Browse files
BowenBaoneginraoofmalfet
authored andcommitted
[ONNX] Support optional type (pytorch#68793) (pytorch#73284)
Summary: Pull Request resolved: pytorch#73284 Some important ops won't support optional type until opset 16, so we can't fully test things end-to-end, but I believe this should be all that's needed. Once ONNX Runtime supports opset 16, we can do more testing and fix any remaining bugs. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D34625646 Pulled By: malfet fbshipit-source-id: 537fcbc1e9d87686cc61f5bd66a997e99cec287b Co-authored-by: BowenBao <[email protected]> Co-authored-by: neginraoof <[email protected]> Co-authored-by: Nikita Shulga <[email protected]> (cherry picked from commit 822e79f)
1 parent b8776e1 commit 679fc90

24 files changed

+1098
-487
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ namespace c10 {
277277
_(onnx, Range) \
278278
_(onnx, Tile) \
279279
_(onnx, Where) \
280+
_(onnx, Optional) \
281+
_(onnx, OptionalGetElement) \
282+
_(onnx, OptionalHasElement) \
280283
FORALL_ATTR_BASE_SYMBOLS(_) \
281284
_(attr, Subgraph) \
282285
_(attr, ReverseSubgraph) \

caffe2/python/onnx/tests/onnx_backend_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@
165165
'|test_optional_.*'
166166
'|test_shape_end_.*'
167167
'|test_shape_start_.*'
168+
'|test_identity_opt_*'
169+
'|test_loop16_seq_none_*'
170+
'|test_if_opt_*'
168171
')')
169172

170173
# Unsupported ops in opset 16

test/onnx/test_models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from model_defs.op_test import DummyNet, ConcatNet, PermuteNet, PReluNet, FakeQuantNet
2121
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
2222

23-
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, disableScriptTest
23+
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, skipScriptTest
2424

2525
import torch
2626
import torch.onnx
@@ -68,7 +68,7 @@ def test_prelu(self):
6868
)
6969
self.exportTest(PReluNet(), x)
7070

71-
@disableScriptTest()
71+
@skipScriptTest()
7272
def test_concat(self):
7373
input_a = Variable(torch.randn(BATCH_SIZE, 3))
7474
input_b = Variable(torch.randn(BATCH_SIZE, 3))
@@ -79,12 +79,12 @@ def test_permute(self):
7979
x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
8080
self.exportTest(PermuteNet(), x)
8181

82-
@disableScriptTest()
82+
@skipScriptTest()
8383
def test_embedding_sequential_1(self):
8484
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
8585
self.exportTest(EmbeddingNetwork1(), x)
8686

87-
@disableScriptTest()
87+
@skipScriptTest()
8888
def test_embedding_sequential_2(self):
8989
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
9090
self.exportTest(EmbeddingNetwork2(), x)
@@ -140,7 +140,7 @@ def test_resnet(self):
140140
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
141141
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)
142142

143-
@disableScriptTest() # None type in outputs
143+
@skipScriptTest(min_opset_version=15) # None type in outputs
144144
def test_inception(self):
145145
x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
146146
self.exportTest(toC(inception_v3()), toC(x))
@@ -163,14 +163,14 @@ def test_densenet(self):
163163
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
164164
self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)
165165

166-
@disableScriptTest()
166+
@skipScriptTest()
167167
def test_dcgan_netD(self):
168168
netD = _netD(1)
169169
netD.apply(weights_init)
170170
input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
171171
self.exportTest(toC(netD), toC(input))
172172

173-
@disableScriptTest()
173+
@skipScriptTest()
174174
def test_dcgan_netG(self):
175175
netG = _netG(1)
176176
netG.apply(weights_init)
@@ -224,7 +224,7 @@ def test_qat_resnet_per_channel(self):
224224

225225
self.exportTest(toC(qat_resnet50), toC(x))
226226

227-
@disableScriptTest() # None type in outputs
227+
@skipScriptTest(min_opset_version=15) # None type in outputs
228228
def test_googlenet(self):
229229
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
230230
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)
@@ -237,7 +237,7 @@ def test_mobilenet(self):
237237
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
238238
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)
239239

240-
@disableScriptTest() # prim_data
240+
@skipScriptTest() # prim_data
241241
def test_shufflenet(self):
242242
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
243243
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

test/onnx/test_pytorch_common.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,11 @@ def wrapper(self):
7575
return wrapper
7676
return skip_dec
7777

78-
# Enables tests for scripting, instead of only tracing the model.
79-
def enableScriptTest():
78+
# skips tests for scripting.
79+
def skipScriptTest(min_opset_version=float("inf")):
8080
def script_dec(func):
8181
def wrapper(self):
82-
self.is_script_test_enabled = True
83-
return func(self)
84-
return wrapper
85-
return script_dec
86-
87-
88-
# Disable tests for scripting.
89-
def disableScriptTest():
90-
def script_dec(func):
91-
def wrapper(self):
92-
self.is_script_test_enabled = False
82+
self.is_script_test_enabled = self.opset_version >= min_opset_version
9383
return func(self)
9484
return wrapper
9585
return script_dec

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,11 +1997,11 @@ def forward(self, lstm_in):
19971997
bias=has_bias,
19981998
num_layers=num_layers,
19991999
)
2000-
lstm_in = [
2000+
lstm_in = ([
20012001
torch.from_numpy(inputs),
20022002
torch.from_numpy(hx),
20032003
torch.from_numpy(hx),
2004-
] + [param.detach() for param in torch_lstm._flat_weights]
2004+
] + [param.detach() for param in torch_lstm._flat_weights],)
20052005

20062006
self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)
20072007

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Owner(s): ["module: onnx"]
2+
3+
"""Tests for onnx export that don't run the exported model."""
4+
5+
import io
6+
import unittest
7+
8+
import onnx
9+
import torch
10+
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
11+
from torch import Tensor
12+
from torch.onnx import symbolic_helper
13+
14+
from typing import Optional, Type
15+
16+
17+
class TestOptionalOutput(unittest.TestCase):
18+
# TODO: Move these tests to test_pytorch_onnx_onnxruntime once
19+
# ONNX Runtime 1.11 is released and supports opset 16.
20+
21+
class IfNoneInput(torch.nn.Module):
22+
def forward(self, x) -> Optional[Tensor]:
23+
y: Optional[Tensor] = None
24+
if x.size(0) > 1:
25+
y = x
26+
return y
27+
28+
class IfNoneOutput(torch.nn.Module):
29+
def forward(self, x) -> Optional[Tensor]:
30+
y: Optional[Tensor] = x
31+
if x.size(0) > 1:
32+
y = None
33+
return y
34+
35+
36+
class LoopNoneInput(torch.nn.Module):
37+
def forward(self, x) -> Optional[Tensor]:
38+
y: Optional[Tensor] = None
39+
for _ in range(x.size(0)):
40+
y = x
41+
return y
42+
43+
class LoopNoneOutput(torch.nn.Module):
44+
def forward(self, x) -> Optional[Tensor]:
45+
y: Optional[Tensor] = x
46+
for _ in range(x.size(0)):
47+
y = None
48+
return y
49+
50+
51+
@parametrize(
52+
"module_class",
53+
(IfNoneInput, IfNoneOutput, LoopNoneInput, LoopNoneOutput),
54+
name_fn=lambda module_class: module_class.__name__)
55+
@parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
56+
def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int):
57+
# Need scripting to preserve control flow for this test to be meaningful.
58+
model = torch.jit.script(module_class())
59+
f = io.BytesIO()
60+
x = torch.ones(x_size)
61+
dynamic_axis_name = "condition"
62+
torch.onnx.export(
63+
model, (x,), f, opset_version=15,
64+
# Ensure condition is not constant
65+
dynamic_axes={"x": {0: dynamic_axis_name}}, input_names=["x"])
66+
exported = onnx.load_from_string(f.getvalue())
67+
expected_elem_type = symbolic_helper.scalar_type_to_onnx[
68+
symbolic_helper.scalar_type_to_pytorch_type.index(x.dtype)].value
69+
expected_output_type = onnx.helper.make_optional_type_proto(
70+
onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)))
71+
self.assertEqual(expected_output_type, exported.graph.output[0].type)
72+
for node in exported.graph.node:
73+
# Both branches output types should match.
74+
if node.op_type == "If":
75+
for attr in node.attribute:
76+
if attr.name in ("then_branch", "else_branch"):
77+
self.assertEqual(expected_output_type, attr.g.output[0].type)
78+
79+
def test_uninitialized_optional(self):
80+
class Module(torch.nn.Module):
81+
def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
82+
if y is not None:
83+
if y.shape[1] < 5:
84+
if y.size(0) == 1:
85+
y = y + 4
86+
else:
87+
return y
88+
return y
89+
90+
y = torch.ones((3, 4), dtype=torch.int)
91+
torch.onnx.export(
92+
torch.jit.script(Module()), y, io.BytesIO(), opset_version=15,
93+
dynamic_axes={"y": {0: "y0", 1: "y1"}}, input_names=["y"])
94+
95+
96+
instantiate_parametrized_tests(TestOptionalOutput)
97+
98+
99+
if __name__ == "__main__":
100+
unittest.main()

0 commit comments

Comments
 (0)