Skip to content

Commit 327f340

Browse files
committed
[ONNX] Update typing and error messages in symbolic_helper
### Description - Clearer error messages with more context - Created `SymbolicValueError` which adds context of the value to the error message - Type annotation example error message: ``` torch.onnx.errors.SymbolicValueError: ONNX symbolic does not understand the Constant node '%1 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 3 3 [ CPULongType{2} ]]() ' specified with descriptor 'is'. [Caused by the value '1 defined in (%1 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 3 3 [ CPULongType{2} ]]() )' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Constant'.] Inputs: Empty Outputs: #0: 1 defined in (%1 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 3 3 [ CPULongType{2} ]]() ) (type 'Tensor') ``` ### Issue - pytorch#77316 (Runtime error during symbolic conversion) ### Testing Unit tested ghstack-source-id: 5ffb9ca Pull Request resolved: pytorch#83007
1 parent 3afa4d3 commit 327f340

File tree

6 files changed

+221
-98
lines changed

6 files changed

+221
-98
lines changed

test/onnx/test_pytorch_onnx_no_runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def forward(self, x):
167167
tm = TraceMe()
168168
tm = torch.jit.trace(tm, torch.rand(3, 4))
169169
f = io.BytesIO()
170-
torch.onnx._export(tm, (torch.rand(3, 4),), f)
170+
torch.onnx.export(tm, (torch.rand(3, 4),), f)
171171

172172
def test_export_tensoroption_to(self):
173173
def foo(x):

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,8 +1321,10 @@ class TensorType(JitType):
13211321
def getInferred(cls) -> TensorType: ...
13221322
def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
13231323
def sizes(self) -> Optional[List[_int]]: ...
1324+
def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
13241325
def strides(self) -> Optional[List[_int]]: ...
13251326
def device(self) -> Optional[_device]: ...
1327+
def dim(self) -> _int: ...
13261328
def dtype(self) -> Optional[_dtype]: ...
13271329
@staticmethod
13281330
def create_from_tensor(t: Tensor) -> TensorType: ...

torch/onnx/errors.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
"""ONNX exporter exceptions."""
2+
from __future__ import annotations
23

4+
import textwrap
35
from typing import Optional
46

7+
from torch import _C
58
from torch.onnx import _constants
69

7-
__all__ = ["OnnxExporterError", "CheckerError", "UnsupportedOperatorError"]
10+
__all__ = [
11+
"OnnxExporterError",
12+
"CheckerError",
13+
"UnsupportedOperatorError",
14+
"SymbolicValueError",
15+
]
816

917

1018
class OnnxExporterError(RuntimeError):
@@ -14,7 +22,7 @@ class OnnxExporterError(RuntimeError):
1422

1523

1624
class CheckerError(OnnxExporterError):
17-
r"""Raised when ONNX checker detects an invalid model."""
25+
"""Raised when ONNX checker detects an invalid model."""
1826

1927
pass
2028

@@ -42,3 +50,50 @@ def __init__(
4250
"it with the right domain and version."
4351
)
4452
super().__init__(msg)
53+
54+
55+
class SymbolicValueError(OnnxExporterError):
56+
"""Errors around TorchScript values and nodes."""
57+
58+
def __init__(self, msg: str, value: _C.Value):
59+
message = (
60+
f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
61+
f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
62+
)
63+
64+
code_location = value.node().sourceRange()
65+
if code_location:
66+
message += f"\n (node defined in {code_location})"
67+
68+
try:
69+
# Add its input and output to the message.
70+
message += "\n\n"
71+
message += textwrap.indent(
72+
(
73+
"Inputs:\n"
74+
+ (
75+
"\n".join(
76+
f" #{i}: {input_} (type '{input_.type()}')"
77+
for i, input_ in enumerate(value.node().inputs())
78+
)
79+
or " Empty"
80+
)
81+
+ "\n"
82+
+ "Outputs:\n"
83+
+ (
84+
"\n".join(
85+
f" #{i}: {output} (type '{output.type()}')"
86+
for i, output in enumerate(value.node().outputs())
87+
)
88+
or " Empty"
89+
)
90+
),
91+
" ",
92+
)
93+
except AttributeError:
94+
message += (
95+
" Failed to obtain its input and output for debugging. "
96+
"Please refer to the TorchScript graph for debugging information."
97+
)
98+
99+
super().__init__(message)

0 commit comments

Comments
 (0)