Commit f148e4c

apbose authored and gs-olive committed

reciprocal lowering pass

1 parent 5bd321c commit f148e4c

File tree: 5 files changed, +168 −7 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions

@@ -8,6 +8,8 @@
 
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.dynamo.conversion import SourceIR, impl
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -54,6 +56,20 @@ def aten_ops_div(
         "input": args[0],
         "other": args[1],
     }
+    # If both are TRTTensor, both are cast to float32
+    if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
+        kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
+            network, kwargs_new["input"], kwargs_new["other"]
+        )
+    # If one is TRTTensor, it is cast to float32
+    elif isinstance(args[0], TRTTensor) and (
+        kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
+    ):
+        kwargs_new["input"] = cast_trt_tensor(network, kwargs_new["input"], trt.float32)
+    elif isinstance(args[1], TRTTensor) and (
+        kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
+    ):
+        kwargs_new["other"] = cast_trt_tensor(network, kwargs_new["other"], trt.float32)
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
         return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
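
Background on the casting above: in eager PyTorch, aten.div.Tensor performs true division, so dividing two integer tensors already yields a floating-point result. The converter casts int8/int32 operands to float32 so the TensorRT engine reproduces that behavior. A minimal eager-mode sketch of the semantics being matched (illustration only, not part of the commit):

import torch

# True division of two int32 tensors promotes the result to float32.
a = torch.tensor([7, 8], dtype=torch.int32)
b = torch.tensor([2, 3], dtype=torch.int32)
q = torch.div(a, b)
print(q)        # tensor([3.5000, 2.6667])
print(q.dtype)  # torch.float32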

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 24 additions & 0 deletions

@@ -5,6 +5,8 @@
     TRTNetwork,
     TRTTensor,
 )
+import torch_tensorrt as trt
+from typing import List
 
 
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
@@ -65,6 +67,28 @@ def cast_trt_tensor(
     return input_val
 
 
+def cast_int_int_div_trt_tensor(
+    network: TRTNetwork,
+    lhs_val: TRTTensor,
+    rhs_val: TRTTensor,
+) -> List[TRTTensor]:
+    """
+    Given two TRT Tensors of `int` data type in a div operation, cast both to float type
+    Args:
+        network (TRTNetwork): A TensorRT network
+        lhs_val (TRTTensor): A TRT Tensor numerator
+        rhs_val (TRTTensor): A TRT Tensor denominator
+    Returns:
+        A list of lhs_val and rhs_val cast to the appropriate data type
+    """
+    if (lhs_val.dtype == trt.int8 or lhs_val.dtype == trt.int32) and (
+        rhs_val.dtype == trt.int8 or rhs_val.dtype == trt.int32
+    ):
+        lhs_val = cast_trt_tensor(network, lhs_val, trt.float32)
+        rhs_val = cast_trt_tensor(network, rhs_val, trt.float32)
+    return list((lhs_val, rhs_val))
+
+
 def broadcastable(
     a: TRTTensor,
     b: TRTTensor,
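
Note that the helper promotes the pair only when both operands are integer typed; mixed int/float pairs are handled by the per-operand casts in aten_ops_div above. A rough plain-PyTorch analogue of the rule (the real helper builds cast layers on a TRTNetwork, so this sketch is for illustration only):

import torch
from typing import List

def cast_int_int_div(lhs: torch.Tensor, rhs: torch.Tensor) -> List[torch.Tensor]:
    # Promote to float32 only when BOTH operands are integer tensors,
    # mirroring cast_int_int_div_trt_tensor.
    int_dtypes = (torch.int8, torch.int32)
    if lhs.dtype in int_dtypes and rhs.dtype in int_dtypes:
        lhs = lhs.to(torch.float32)
        rhs = rhs.to(torch.float32)
    return [lhs, rhs]

lhs, rhs = cast_int_int_div(
    torch.randint(1, 10, (5,), dtype=torch.int32),
    torch.randint(1, 10, (5,), dtype=torch.int32),
)
assert lhs.dtype == rhs.dtype == torch.float32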

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 7 additions & 0 deletions

@@ -70,5 +70,12 @@ def addmm_replacement(
     )
 
 
+@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS)
+def reciprocal_replacement(
+    input_: torch.Tensor,
+) -> torch.Tensor:
+    return torch.div(1, input_)
+
+
 def get_decompositions():
     return DECOMPOSITIONS
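
The decomposition itself is the identity reciprocal(x) = 1 / x, expressed with torch.div so that later passes see aten.div.Tensor instead of aten.reciprocal.default. A quick eager-mode check of that identity (a minimal sketch, not part of the commit):

import torch

x = torch.randn(5).abs() + 0.1  # keep inputs safely away from zero
original = torch.ops.aten.reciprocal.default(x)
lowered = torch.div(1, x)  # the form produced by reciprocal_replacement
assert torch.allclose(original, lowered)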

tests/py/dynamo/backend/test_decompositions.py

Lines changed: 69 additions & 7 deletions

@@ -78,15 +78,14 @@ def forward(self, x):
                 return y
 
         # Operations expected to be removed in the traced graph after decompositions
-        expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.reciprocal.default}
-        unexpected_ops = {torch.ops.aten.rsqrt.default}
+        expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.div.Tensor}
+        unexpected_ops = {
+            torch.ops.aten.rsqrt.default,
+            torch.ops.aten.reciprocal.default,
+        }
 
         inputs = [
-            torch.randint(
-                1,
-                10,
-                (5,),
-            ),
+            torch.randint(1, 10, (5,), dtype=torch.int32),
         ]
 
         fx_graph = torch.fx.symbolic_trace(Rsqrt())
@@ -182,6 +181,69 @@ def forward(self, x, y, z):
             f"AddMM TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_reciprocal(self):
+        class Reciprocal(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x):
+                y = torch.ops.aten.reciprocal.default(x)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.div.Tensor}
+        unexpected_ops = {torch.ops.aten.reciprocal.default}
+
+        inputs = [
+            torch.randn(
+                5,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Reciprocal())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_tensorrt",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Reciprocal TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
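
The test relies on the repo's lower_graph_testing utility to check which ops survive lowering. The same replace-by-decomposition effect can be reproduced stand-alone with make_fx and a custom decomposition table; the DECOMPOSITIONS dict below is a hypothetical stand-in for the repo's registry, used here for illustration only:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical stand-in for the registry exercised by the test above.
DECOMPOSITIONS = {torch.ops.aten.reciprocal.default: lambda x: torch.div(1, x)}

def fn(x):
    return torch.ops.aten.reciprocal.default(x)

gm = make_fx(fn, decomposition_table=DECOMPOSITIONS)(torch.randn(5))
ops = {n.target for n in gm.graph.nodes if n.op == "call_function"}
# After lowering, the graph contains aten.div and no aten.reciprocal.
assert torch.ops.aten.reciprocal.default not in ops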

tests/py/dynamo/converters/test_binary_ops_aten.py

Lines changed: 52 additions & 0 deletions

@@ -75,6 +75,23 @@ def forward(self, x):
         inputs = [torch.rand(1, 1) + 1]
         self.run_test(m, inputs, expected_ops={expected_op})
 
+    @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
+    def test_elementwise_ops_mismatched_dtypes(
+        self, name, orig_op: Callable, expected_op
+    ):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                return self.orig_op(x.int(), x)
+
+        m = TestModule(orig_op)
+        # Avoid dividing by 0.
+        inputs = [2 * torch.rand(1, 1, dtype=torch.float) + 1]
+        self.run_test(m, inputs, expected_ops={expected_op})
+
     @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
     def test_elementwise_ops_with_one_constant(
         self, name, orig_op: Callable, expected_op
@@ -114,6 +131,41 @@ def forward(self, x):
         inputs = [torch.randn(2, 2)]
         self.run_test(m, inputs, expected_ops={expected_op})
 
+    @parameterized.expand([((lambda x, y: x / y), torch.ops.aten.div.Tensor)])
+    def test_elementwise_op_div_with_two_ints(self, orig_op: Callable, expected_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                return self.orig_op(x, x + 1)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)]
+        self.run_test(m, inputs, expected_ops={expected_op})
+
+    @parameterized.expand([((lambda x, y: x / y), torch.ops.aten.div.Tensor)])
+    def test_elementwise_op_div_with_one_int_one_constant(
+        self, orig_op: Callable, expected_op
+    ):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant1 = torch.nn.Parameter(
+                    torch.randn(
+                        5,
+                    )
+                )
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                return self.orig_op(x, self.constant1)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)]
+        self.run_test(m, inputs, expected_ops={expected_op})
+
     # Dynamic shape test
     @parameterized.expand(
         [
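
For reference, the eager-mode dtype behavior these converter tests pin down, shown as a small sketch (not part of the commit):

import torch

x = torch.rand(1, 1) + 1
print((x.int() / x).dtype)  # mixed int/float division -> torch.float32

i = torch.randint(1, 10, (5,), dtype=torch.int32)
print((i / (i + 1)).dtype)  # int/int true division also -> torch.float32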
