Skip to content

Commit 0912d03

Browse files
chohk88laikhtewari
authored andcommitted
feat: support aten.atan2 converter (#2689)
1 parent b5618aa commit 0912d03

File tree

3 files changed

+334
-1
lines changed

3 files changed

+334
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,30 @@ def aten_ops_atanh(
13911391
)
13921392

13931393

1394+
@dynamo_tensorrt_converter(torch.ops.aten.atan2.default)
1395+
@enforce_tensor_types(
1396+
{
1397+
0: (TRTTensor,),
1398+
1: (TRTTensor,),
1399+
}
1400+
)
1401+
def aten_ops_atan2(
1402+
ctx: ConversionContext,
1403+
target: Target,
1404+
args: Tuple[Argument, ...],
1405+
kwargs: Dict[str, Argument],
1406+
name: str,
1407+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1408+
return impl.elementwise.atan2(
1409+
ctx,
1410+
target,
1411+
SourceIR.ATEN,
1412+
name,
1413+
args[0],
1414+
args[1],
1415+
)
1416+
1417+
13941418
@dynamo_tensorrt_converter(torch.ops.aten.ceil.default)
13951419
def aten_ops_ceil(
13961420
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, Union
22

3+
import numpy as np
34
import tensorrt as trt
45
import torch
56
import torch_tensorrt.dynamo.conversion.impl as impl
@@ -9,13 +10,15 @@
910
from torch_tensorrt.dynamo.conversion.converter_utils import (
1011
cast_int_int_div_trt_tensor,
1112
cast_int_or_float_to_bool,
13+
cast_trt_tensor,
1214
get_trt_tensor,
1315
)
1416
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1517
convert_binary_elementwise,
1618
)
17-
from torch_tensorrt.dynamo.conversion.impl.unary import sign
19+
from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
1820
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
21+
from torch_tensorrt.fx.converters.converter_utils import broadcast
1922
from torch_tensorrt.fx.types import TRTTensor
2023
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2124

@@ -213,6 +216,180 @@ def remainder(
213216
return fmod2_value
214217

215218

219+
def atan2(
220+
ctx: ConversionContext,
221+
target: Target,
222+
source_ir: Optional[SourceIR],
223+
name: str,
224+
input: TRTTensor,
225+
other: TRTTensor,
226+
) -> TRTTensor:
227+
"""
228+
Perform atan2 operation on Tensor, calculating the arctangent of the quotient of input tensors.
229+
atan2(x,y) = atan(x/y) if y > 0,
230+
= atan(x/y) + π if x ≥ 0 and y < 0,
231+
= atan(x/y) - π if x < 0 and y < 0,
232+
= π/2 if x > 0 and y = 0,
233+
= -π/2 if x < 0 and y = 0,
234+
= 0 if x = 0 and y = 0
235+
236+
Args:
237+
ctx: ConversionContext.
238+
target: node target
239+
source_ir (SourceIR): Source IR calling the function.
240+
name: namespace for the op
241+
input: Tensor or constant representing the dividend.
242+
other: Tensor or constant representing the divisor.
243+
244+
Returns:
245+
A TensorRT tensor representing the result of the atan2 operation.
246+
"""
247+
pi_value = 3.141592653589793
248+
pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi")
249+
250+
if isinstance(input, TRTTensor):
251+
input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input")
252+
if isinstance(other, TRTTensor):
253+
other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
254+
255+
input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other")
256+
257+
# Calculate x_zero, y_zero (whether inputs are zero)
258+
x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0)
259+
y_zero = eq(ctx, target, source_ir, f"{name}_y_zero", other, 0)
260+
261+
# Get sign of inputs
262+
x_positive = gt(ctx, target, source_ir, f"{name}_x_positive", input, 0)
263+
x_zero_positive = ge(ctx, target, source_ir, f"{name}_x_zero_positive", input, 0)
264+
x_negative = lt(ctx, target, source_ir, f"{name}_x_negative", input, 0)
265+
y_positive = gt(ctx, target, source_ir, f"{name}_y_positive", other, 0)
266+
y_negative = lt(ctx, target, source_ir, f"{name}_y_negative", other, 0)
267+
268+
# Calculate atan(x/y)
269+
input_div_other = div(
270+
ctx, target, source_ir, f"{name}_input_div_other", input, other
271+
)
272+
atan_val = atan(ctx, target, source_ir, f"{name}_atan", input_div_other)
273+
274+
# atan(x/y)+π if x≥0 and y<0,
275+
atan_add_pi = add(
276+
ctx, target, source_ir, f"{name}_atan_add_pi", atan_val, pi_tensor
277+
)
278+
279+
# atan(x/y)-π if x<0 and y<0,
280+
atan_sub_pi = sub(
281+
ctx, target, source_ir, f"{name}_atan_sub_pi", atan_val, pi_tensor
282+
)
283+
284+
# atan(x/y)+π if x≥0 and y<0,
285+
atan_corrected = impl.condition.select(
286+
ctx,
287+
target,
288+
source_ir,
289+
f"{name}_atan_corrected",
290+
atan_add_pi,
291+
atan_val,
292+
logical_and(
293+
ctx,
294+
target,
295+
source_ir,
296+
f"{name}_x_zero_positive_and_y_negative",
297+
x_zero_positive,
298+
y_negative,
299+
),
300+
)
301+
302+
# atan(x/y)-π if x<0 and y<0,
303+
atan_corrected_2 = impl.condition.select(
304+
ctx,
305+
target,
306+
source_ir,
307+
f"{name}_atan_corrected_2",
308+
atan_sub_pi,
309+
atan_corrected,
310+
logical_and(
311+
ctx,
312+
target,
313+
source_ir,
314+
f"{name}_x_negative_and_y_negative",
315+
x_negative,
316+
y_negative,
317+
),
318+
)
319+
320+
# atan(x/y) if y>0
321+
atan_output = impl.condition.select(
322+
ctx,
323+
target,
324+
source_ir,
325+
f"{name}_atan_output",
326+
atan_val,
327+
atan_corrected_2,
328+
y_positive,
329+
)
330+
331+
# on x or y-axis
332+
pi_over_2_tensor = get_trt_tensor(
333+
ctx,
334+
(pi_value / 2) * np.ones(input.shape, dtype=np.float32),
335+
f"{name}_pi_over_2_tensor",
336+
dtype=trt.float32,
337+
)
338+
minus_pi_over_2_tensor = get_trt_tensor(
339+
ctx,
340+
(-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
341+
f"{name}_minus_pi_over_2_tensor",
342+
dtype=trt.float32,
343+
)
344+
zero_tensor = get_trt_tensor(
345+
ctx,
346+
np.zeros(input.shape, dtype=np.float32),
347+
f"{name}_zero_tensor",
348+
dtype=trt.float32,
349+
)
350+
351+
# π/2 if x>0 and y=0,
352+
pi_over_2_output = impl.condition.select(
353+
ctx,
354+
target,
355+
source_ir,
356+
f"{name}_pi_over_2_output",
357+
pi_over_2_tensor,
358+
atan_output,
359+
logical_and(
360+
ctx, target, source_ir, f"{name}_x_zero_and_y_positive", x_positive, y_zero
361+
),
362+
)
363+
364+
# -π/2 if x<0 and y=0,
365+
minus_pi_over_2_output = impl.condition.select(
366+
ctx,
367+
target,
368+
source_ir,
369+
f"{name}_minus_pi_over_2_output",
370+
minus_pi_over_2_tensor,
371+
pi_over_2_output,
372+
logical_and(
373+
ctx, target, source_ir, f"{name}_x_zero_and_y_negative", x_negative, y_zero
374+
),
375+
)
376+
377+
# 0 if x=0 and y=0,
378+
zero_output = impl.condition.select(
379+
ctx,
380+
target,
381+
source_ir,
382+
f"{name}_zero_output",
383+
zero_tensor,
384+
minus_pi_over_2_output,
385+
logical_and(
386+
ctx, target, source_ir, f"{name}_x_zero_and_y_zero", y_zero, x_zero
387+
),
388+
)
389+
390+
return zero_output
391+
392+
216393
def clamp(
217394
ctx: ConversionContext,
218395
target: Target,
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestAtan2Converter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((10,), torch.float),
13+
((1, 20), torch.float),
14+
((2, 3, 4), torch.float),
15+
((2, 3, 4, 5), torch.float),
16+
]
17+
)
18+
def test_atan2_lhs_const(self, input_shape, dtype):
19+
class atan2(nn.Module):
20+
def forward(self, lhs_val, rhs_val):
21+
return torch.ops.aten.atan2.default(lhs_val, rhs_val)
22+
23+
inputs = [
24+
torch.randn(input_shape, dtype=dtype),
25+
torch.rand(1),
26+
]
27+
28+
self.run_test(
29+
atan2(),
30+
inputs,
31+
)
32+
33+
@parameterized.expand(
34+
[
35+
((10,), torch.float),
36+
((1, 20), torch.float),
37+
((2, 3, 4), torch.float),
38+
((2, 3, 4, 5), torch.float),
39+
]
40+
)
41+
def test_atan2_rhs_const(self, input_shape, dtype):
42+
class atan2(nn.Module):
43+
def forward(self, lhs_val, rhs_val):
44+
return torch.ops.aten.atan2.default(lhs_val, rhs_val)
45+
46+
inputs = [
47+
torch.rand(1),
48+
torch.randn(input_shape, dtype=dtype),
49+
]
50+
51+
self.run_test(
52+
atan2(),
53+
inputs,
54+
)
55+
56+
@parameterized.expand(
57+
[
58+
((10,), torch.float),
59+
((1, 20), torch.float),
60+
((2, 3, 4), torch.float),
61+
((2, 3, 4, 5), torch.float),
62+
]
63+
)
64+
def test_atan2_float(self, input_shape, dtype):
65+
class atan2(nn.Module):
66+
def forward(self, lhs_val, rhs_val):
67+
return torch.ops.aten.atan2.default(lhs_val, rhs_val)
68+
69+
inputs = [
70+
torch.randn(input_shape, dtype=dtype),
71+
torch.randn(input_shape, dtype=dtype),
72+
]
73+
74+
self.run_test(
75+
atan2(),
76+
inputs,
77+
)
78+
79+
@parameterized.expand(
80+
[
81+
((50,), torch.int, -5, 5),
82+
((1, 20), torch.int32, -5, 5),
83+
((2, 3, 4), torch.int, -5, 5),
84+
]
85+
)
86+
def test_atan2_int(self, input_shape, dtype, low, high):
87+
class atan2(nn.Module):
88+
def forward(self, lhs_val, rhs_val):
89+
return torch.ops.aten.atan2.default(lhs_val, rhs_val)
90+
91+
inputs = [
92+
torch.randint(low, high, input_shape, dtype=dtype),
93+
torch.randint(low, high, input_shape, dtype=dtype),
94+
]
95+
self.run_test(
96+
atan2(),
97+
inputs,
98+
)
99+
100+
@parameterized.expand(
101+
[
102+
(torch.float, 0.0, 0.0),
103+
(torch.float, 0.0, torch.rand(1)),
104+
(torch.float, torch.rand(1), 0.0),
105+
(torch.int, 0, 0),
106+
(torch.int, 0, torch.randint(-5, 5, (1,))),
107+
(torch.int, torch.randint(1, 10, (1,)), 0),
108+
]
109+
)
110+
def test_atan2_zero(self, dtype, x_val, y_val):
111+
class Atan2(nn.Module):
112+
def forward(self, lhs_val, rhs_val):
113+
return torch.ops.aten.atan2.default(lhs_val, rhs_val)
114+
115+
if isinstance(x_val, torch.Tensor):
116+
x_val = x_val.item()
117+
if isinstance(y_val, torch.Tensor):
118+
y_val = y_val.item()
119+
120+
inputs = [
121+
torch.tensor([x_val], dtype=dtype),
122+
torch.tensor([y_val], dtype=dtype),
123+
]
124+
125+
self.run_test(
126+
Atan2(),
127+
inputs,
128+
)
129+
130+
131+
if __name__ == "__main__":
132+
run_tests()

0 commit comments

Comments
 (0)