Skip to content

Commit 82a5258

Browse files
BowenBaofacebook-github-bot
authored andcommitted
[ONNX] Relax node constraint for onnx shape inference (#77379) (#77379)
Summary: None as input is legal per ONNX spec for representing optional inputs. For [example](https://github.com/onnx/onnx/blob/main/docs/Operators.md#inputs-2---3-7) `constant_value` for `ONNX::Pad`. This PR removes such constraint check that was set prior to calling onnx shape inference. For the issue below, such constraint prevents the onnx shape inference of `ONNX::Pad`, which leads to falling back on an incorrect constant traced shape. For the unit test in this PR, prior to this PR, the ONNX shape inference for `ONNX::Pad` would be skipped, and would return `None` instead. Fixes pytorch/vision#5971 Pull Request resolved: #77379 Approved by: https://github.com/garymm Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a812c4cd96d94d51627d2af290ae87de34169ec0 Reviewed By: atalman Differential Revision: D36378914 fbshipit-source-id: a11be0f9666dd637490db80725f7021b328c9f27
1 parent a93c044 commit 82a5258

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

test/onnx/test_pytorch_onnx_shape_inference.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,15 @@ def test_expand(self):
201201
expand = g.op("Expand", constant, shape)
202202
self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
203203

204+
def test_pad(self):
205+
g = self.create_empty_graph()
206+
input = g.addInput()
207+
input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100]))
208+
constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long))
209+
none = g.op("prim::Constant").setType(torch.NoneType.get())
210+
pad = g.op("Pad", input, constant, none, mode_s="constant")
211+
self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None)))
212+
204213

205214
if __name__ == "__main__":
206215
unittest.main()

torch/csrc/jit/passes/onnx/shape_type_inference.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,6 @@ bool IsValidONNXNode(const Node* n) {
226226
}
227227
}
228228

229-
for (auto inp : n->inputs()) {
230-
if (inp->type() == NoneType::get()) {
231-
return false;
232-
}
233-
}
234-
235229
return true;
236230
}
237231

0 commit comments

Comments
 (0)