Skip to content

Commit 66b427a

Browse files
committed
Revert "[ONNX] Support optional type (pytorch#68793) (pytorch#73284)"
This reverts commit 679fc90.
1 parent e06400e commit 66b427a

23 files changed

+533
-1167
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,6 @@ namespace c10 {
279279
_(onnx, Range) \
280280
_(onnx, Tile) \
281281
_(onnx, Where) \
282-
_(onnx, Optional) \
283-
_(onnx, OptionalGetElement) \
284-
_(onnx, OptionalHasElement) \
285282
FORALL_ATTR_BASE_SYMBOLS(_) \
286283
_(attr, Subgraph) \
287284
_(attr, ReverseSubgraph) \

caffe2/python/onnx/tests/onnx_backend_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,6 @@
165165
'|test_optional_.*'
166166
'|test_shape_end_.*'
167167
'|test_shape_start_.*'
168-
'|test_identity_opt_*'
169-
'|test_loop16_seq_none_*'
170-
'|test_if_opt_*'
171168
')')
172169

173170
# Unsupported ops in opset 16

test/onnx/test_models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
run_tests,
2121
skipIfNoLapack,
2222
skipIfUnsupportedMinOpsetVersion,
23-
skipScriptTest,
23+
disableScriptTest,
2424
)
2525
from torchvision.models import shufflenet_v2_x1_0
2626
from torchvision.models.alexnet import alexnet
@@ -82,7 +82,7 @@ def test_prelu(self):
8282
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
8383
self.exportTest(PReluNet(), x)
8484

85-
@skipScriptTest()
85+
@disableScriptTest()
8686
def test_concat(self):
8787
input_a = Variable(torch.randn(BATCH_SIZE, 3))
8888
input_b = Variable(torch.randn(BATCH_SIZE, 3))
@@ -93,12 +93,12 @@ def test_permute(self):
9393
x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
9494
self.exportTest(PermuteNet(), x)
9595

96-
@skipScriptTest()
96+
@disableScriptTest()
9797
def test_embedding_sequential_1(self):
9898
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
9999
self.exportTest(EmbeddingNetwork1(), x)
100100

101-
@skipScriptTest()
101+
@disableScriptTest()
102102
def test_embedding_sequential_2(self):
103103
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
104104
self.exportTest(EmbeddingNetwork2(), x)
@@ -152,7 +152,7 @@ def test_resnet(self):
152152
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
153153
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)
154154

155-
@skipScriptTest(min_opset_version=15) # None type in outputs
155+
@disableScriptTest() # None type in outputs
156156
def test_inception(self):
157157
x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
158158
self.exportTest(toC(inception_v3()), toC(x))
@@ -175,14 +175,14 @@ def test_densenet(self):
175175
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
176176
self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)
177177

178-
@skipScriptTest()
178+
@disableScriptTest()
179179
def test_dcgan_netD(self):
180180
netD = _netD(1)
181181
netD.apply(weights_init)
182182
input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
183183
self.exportTest(toC(netD), toC(input))
184184

185-
@skipScriptTest()
185+
@disableScriptTest()
186186
def test_dcgan_netG(self):
187187
netG = _netG(1)
188188
netG.apply(weights_init)
@@ -239,7 +239,7 @@ def test_qat_resnet_per_channel(self):
239239

240240
self.exportTest(toC(qat_resnet50), toC(x))
241241

242-
@skipScriptTest(min_opset_version=15) # None type in outputs
242+
@disableScriptTest() # None type in outputs
243243
def test_googlenet(self):
244244
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
245245
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)
@@ -252,7 +252,7 @@ def test_mobilenet(self):
252252
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
253253
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)
254254

255-
@skipScriptTest() # prim_data
255+
@disableScriptTest() # prim_data
256256
def test_shufflenet(self):
257257
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
258258
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

test/onnx/test_pytorch_common.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,21 @@ def wrapper(self):
9191

9292
return skip_dec
9393

94+
# Enables tests for scripting, instead of only tracing the model.
95+
def enableScriptTest():
96+
def script_dec(func):
97+
def wrapper(self):
98+
self.is_script_test_enabled = True
99+
return func(self)
100+
return wrapper
101+
return script_dec
102+
94103

95-
# skips tests for scripting.
96-
def skipScriptTest(min_opset_version=float("inf")):
104+
# Disable tests for scripting.
105+
def disableScriptTest():
97106
def script_dec(func):
98107
def wrapper(self):
99-
self.is_script_test_enabled = self.opset_version >= min_opset_version
108+
self.is_script_test_enabled = False
100109
return func(self)
101110

102111
return wrapper

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,14 +2573,11 @@ def forward(self, lstm_in):
25732573
bias=has_bias,
25742574
num_layers=num_layers,
25752575
)
2576-
lstm_in = (
2577-
[
2578-
torch.from_numpy(inputs),
2579-
torch.from_numpy(hx),
2580-
torch.from_numpy(hx),
2581-
]
2582-
+ [param.detach() for param in torch_lstm._flat_weights],
2583-
)
2576+
lstm_in = [
2577+
torch.from_numpy(inputs),
2578+
torch.from_numpy(hx),
2579+
torch.from_numpy(hx),
2580+
] + [param.detach() for param in torch_lstm._flat_weights]
25842581

25852582
self.run_model_test(
25862583
MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False

0 commit comments

Comments
 (0)