
Commit fcc36de

BowenBao authored and pytorchmergebot committed
[ONNX][dynamo_export] Turn off opmath type promotion for div (pytorch#119112)
Skip opmath promotion for `_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT` as well. Fixes pytorch#118941

Pull Request resolved: pytorch#119112
Approved by: https://github.com/thiagocrepaldi
1 parent 45a7932 commit fcc36de
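
What this changes in practice: with opmath type promotion, the exporter upcasts reduced-precision inputs to an internal computation dtype (e.g. float16 to float32) and materializes the upcast as Cast nodes around the op; since aten::div promotes with the INT_TO_FLOAT kind, fp16 divisions were being wrapped in fp32 casts. The sketch below mirrors the new test added in this commit (the `Model` class and tensor shape are illustrative, not part of the commit):

# A minimal repro sketch mirroring the new test in test_fx_to_onnx.py below.
# With opmath promotion skipped for INT_TO_FLOAT, the exported Div consumes
# the fp16 graph input directly, with no Cast inserted before it.
import onnx
import onnx.inliner
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x / 2  # aten::div promotes with the INT_TO_FLOAT kind

model_proto = torch.onnx.dynamo_export(
    Model(), torch.randn(3, 5, dtype=torch.float16)
).model_proto
model_proto = onnx.inliner.inline_local_functions(model_proto)
div_node = next(n for n in model_proto.graph.node if n.op_type == "Div")
# Before this fix, a Cast to float32 sat between the graph input and Div.
print(div_node.input[0] == model_proto.graph.input[0].name)  # expected: True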

File tree

4 files changed: +33 -17 lines changed


test/onnx/test_fx_op_consistency.py

Lines changed: 1 addition & 0 deletions
@@ -1927,6 +1927,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
         "linalg.multi_dot": [3e-2, 1e-3],
         "linalg.vecdot": [1e-2, 2e-2],
         "linspace": [2e-2, 2e-3],
+        "masked.var": [2e-2, 1e-3],
         "matmul": [2e-2, 6e-2],
         "nn.functional.batch_norm": [3e-2, 1e-3],
         "nn.functional.binary_cross_entropy": [3e-2, 1e-3],

test/onnx/test_fx_to_onnx.py

Lines changed: 17 additions & 13 deletions
@@ -1,13 +1,12 @@
 # Owner(s): ["module: onnx"]
 from __future__ import annotations
 
-import io
-
 import tempfile
 
 from typing import Mapping, Tuple
 
 import onnx
+import onnx.inliner
 import pytorch_test_common
 import torch
 import transformers  # type: ignore[import]
@@ -621,17 +620,22 @@ def forward(self, x):
             x,
         )
 
-    def test_aten_linalg_vector_norm_with_reducel2(self):
-        class Net(nn.Module):
-            def forward(self, x):
-                x = F.normalize(x)
-                return x
-
-        f = io.BytesIO()
-        torch.onnx.export(Net(), (torch.randn(1, 2, 2),), f)
-        onnx_model = onnx.load_from_string(f.getvalue())
-        onnx_nodes = [n.op_type for n in onnx_model.graph.node]
-        self.assertTrue("ReduceL2" in onnx_nodes)
+    def test_aten_div_no_opmath_type_promotion(self):
+        class Model(torch.nn.Module):
+            def forward(self, input):
+                return input / 2
+
+        model = Model()
+        input = torch.randn(3, 5, requires_grad=True, dtype=torch.float16)
+
+        model_proto = torch.onnx.dynamo_export(model, input).model_proto
+        model_proto = onnx.inliner.inline_local_functions(model_proto)
+        div_node = next(
+            node for node in model_proto.graph.node if node.op_type == "Div"
+        )
+        # The input of Div node should be the input of the model,
+        # with no Cast node in between.
+        self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
 
     def test_exported_program_as_input_with_model_signature(self):
         class Model(torch.nn.Module):

test/onnx/test_pytorch_onnx_no_runtime.py

Lines changed: 12 additions & 0 deletions
@@ -1314,6 +1314,18 @@ def test_aten_device_with_index(self):
             decoder_attention_mask=ids["attention_mask"],
         )
 
+    def test_aten_linalg_vector_norm_with_reducel2(self):
+        class Net(torch.nn.Module):
+            def forward(self, x):
+                x = F.normalize(x)
+                return x
+
+        f = io.BytesIO()
+        torch.onnx.export(Net(), (torch.randn(1, 2, 2),), f)
+        onnx_model = onnx.load_from_string(f.getvalue())
+        onnx_nodes = [n.op_type for n in onnx_model.graph.node]
+        self.assertIn("ReduceL2", onnx_nodes)
+
 
 if __name__ == "__main__":
     common_utils.run_tests()

torch/onnx/_internal/fx/passes/type_promotion.py

Lines changed: 3 additions & 4 deletions
@@ -184,10 +184,9 @@ def _consolidate_input_dtype(
         since there is no way to differentiate between inserted upcasts and model code
         casts. Hence we consolidate the input dtype to the result dtype to avoid this.
         """
-        if (
-            not self._USE_OPMATH
-            and self.promotion_kind
-            == _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+        if not self._USE_OPMATH and self.promotion_kind in (
+            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
         ):
             return result_dtype
         return computed_dtype
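
For reference, a standalone sketch of the consolidated rule above (an illustrative paraphrase, not the pass itself; the module-level `_USE_OPMATH` flag stands in for the class-level switch):

import torch
import torch._prims_common as _prims_common

_USE_OPMATH = False  # stands in for the pass-level switch

def consolidate_input_dtype(promotion_kind, computed_dtype, result_dtype):
    # With opmath off, DEFAULT and (after this commit) INT_TO_FLOAT promotions
    # reuse the result dtype for the inputs, so no upcast Cast is emitted.
    if not _USE_OPMATH and promotion_kind in (
        _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ):
        return result_dtype
    return computed_dtype

# div on fp16 tensors: opmath would compute in fp32, but inputs now keep fp16.
print(consolidate_input_dtype(
    _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    torch.float32,  # opmath computation dtype
    torch.float16,  # result dtype
))  # -> torch.float16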
