
Commit 8ae9eff

feat: support aten.remainder.Scalar and aten.remainder.Tensor (#2566)

1 parent 2df45cf commit 8ae9eff

File tree

3 files changed: +119 −0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 0 deletions
@@ -2572,3 +2572,27 @@ def aten_ops_copy(
         src.dtype,
         force_layer=True,
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.remainder.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.remainder.Tensor)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_remainder(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.remainder(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
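For orientation: converters registered with @dynamo_tensorrt_converter are picked up when the Dynamo frontend encounters the matching ATen op in an exported graph. A minimal sketch of exercising this converter end to end — the module, shapes, and device here are illustrative, not from the commit, and assume a working CUDA + TensorRT install:

    import torch
    import torch_tensorrt

    class Mod(torch.nn.Module):
        def forward(self, x, y):
            # torch.remainder on two tensors traces to aten.remainder.Tensor
            return torch.remainder(x, y)

    model = Mod().eval().cuda()
    inputs = [torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()]
    # ir="dynamo" selects the conversion path where aten_ops_remainder is registered
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
    print(trt_model(*inputs))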

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 35 additions & 0 deletions
@@ -178,6 +178,41 @@ def fmod(
     return sub_value


+def remainder(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    other: TRTTensor,
+) -> TRTTensor:
+    fmod1_value = fmod(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_fmod1",
+        input,
+        other,
+    )
+    added_value = add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        fmod1_value,
+        other,
+    )
+    fmod2_value = fmod(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_fmod2",
+        added_value,
+        other,
+    )
+    return fmod2_value
+
+
 def clamp(
     ctx: ConversionContext,
     target: Target,
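This helper composes Python-style (floored) remainder from the existing fmod converter: fmod takes the sign of the dividend, aten.remainder takes the sign of the divisor, and fmod(fmod(a, b) + b, b) maps the former onto the latter. A quick eager-mode check of that identity in plain PyTorch (not part of the commit):

    import torch

    a = torch.tensor([5.0, -5.0, 5.0, -5.0])
    b = torch.tensor([3.0, 3.0, -3.0, -3.0])

    # fmod follows the dividend's sign:      [2., -2.,  2., -2.]
    # remainder follows the divisor's sign:  [2.,  1., -1., -2.]
    assert torch.equal(torch.fmod(torch.fmod(a, b) + b, b), torch.remainder(a, b))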
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestRemainderConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("1d", (5,), 3),
+            ("2d", (2, 1), 1.0),
+            ("3d", (2, 1, 2), 2),
+        ]
+    )
+    def test_remainder_scalar(self, _, shape, scalar):
+        class Remainder(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.remainder.Scalar(lhs_val, scalar)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Remainder(),
+            inputs,
+        )
+
+    def test_remainder_scalar_int(self, scalar=3):
+        class Remainder(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.remainder.Scalar(lhs_val, scalar)
+
+        inputs = [torch.tensor([0, 1, 2, 3, 4, -1, -2, -3, -4], dtype=torch.float32)]
+        self.run_test(
+            Remainder(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("1d", (5,)),
+            ("2d", (2, 1)),
+            ("3d", (2, 1, 2)),
+        ]
+    )
+    def test_remainder_tensor(self, _, shape):
+        class Remainder(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.remainder.Tensor(lhs_val, rhs_val)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            Remainder(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
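The two overloads tested above differ only in the type of the right-hand operand, and in eager mode both agree with torch.remainder. A small illustration (not part of the commit):

    import torch

    x = torch.tensor([-4.0, -1.0, 3.0])
    # Scalar overload: right-hand side is a Python number
    print(torch.ops.aten.remainder.Scalar(x, 3))                      # tensor([2., 2., 0.])
    # Tensor overload: right-hand side is a tensor
    print(torch.ops.aten.remainder.Tensor(x, torch.full_like(x, 3)))  # tensor([2., 2., 0.])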
