Skip to content

Commit c8a9559

Browse files
committed
Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) <Target: converter_reorg_elementwise> (#1905)
1 parent d23cdcd commit c8a9559

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
2525
from torch_tensorrt.fx.converters.impl import activation
2626
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
27+
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
2728

2829
_LOGGER: logging.Logger = logging.getLogger(__name__)
2930

@@ -368,6 +369,42 @@ def aten_ops_relu(
368369
)
369370

370371

372+
@tensorrt_converter(torch.ops.aten.relu.default)
373+
def aten_ops_relu(
374+
network: TRTNetwork,
375+
target: Target,
376+
args: Tuple[Argument, ...],
377+
kwargs: Dict[str, Argument],
378+
name: str,
379+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
380+
381+
return activation.relu(
382+
network,
383+
target,
384+
SourceIR.ATEN,
385+
name,
386+
args[0],
387+
)
388+
389+
390+
@tensorrt_converter(torch.ops.aten.rsqrt.default)
391+
def aten_ops_rsqrt(
392+
network: TRTNetwork,
393+
target: Target,
394+
args: Tuple[Argument, ...],
395+
kwargs: Dict[str, Argument],
396+
name: str,
397+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
398+
399+
return rsqrt(
400+
network,
401+
target,
402+
SourceIR.ATEN,
403+
name,
404+
args[0],
405+
)
406+
407+
371408
@tensorrt_converter(torch.ops.aten.sub.Tensor)
372409
def aten_ops_sub(
373410
network: TRTNetwork,

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,33 @@ def trunc_div(
109109
)
110110

111111
return output
112+
113+
114+
def rsqrt(
115+
network: TRTNetwork,
116+
target: Target,
117+
source_ir: Optional[SourceIR],
118+
name: str,
119+
input: TRTTensor,
120+
) -> TRTTensor:
121+
122+
sqrt_trt_output = convert_unary(
123+
network,
124+
target,
125+
source_ir,
126+
f"{name}_sqrt",
127+
trt.UnaryOperation.SQRT,
128+
input,
129+
)
130+
131+
output = convert_binary_elementwise(
132+
network,
133+
target,
134+
source_ir,
135+
f"{name}_output",
136+
trt.ElementWiseOperation.DIV,
137+
1,
138+
sqrt_trt_output,
139+
)
140+
141+
return output
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
6+
7+
8+
class TestRSqrtConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
("2d_dim_alpha", (2, 1), 2),
12+
("3d_dim_alpha", (2, 1, 2), 2),
13+
]
14+
)
15+
def test_rsqrt(self, _, x, alpha):
16+
class rsqrt(nn.Module):
17+
def forward(self, input):
18+
return torch.rsqrt(input)
19+
20+
inputs = [torch.randn(x) + 1]
21+
self.run_test(
22+
rsqrt(),
23+
inputs,
24+
expected_ops={torch.ops.aten.rsqrt.default},
25+
)
26+
27+
28+
if __name__ == "__main__":
29+
run_tests()

0 commit comments

Comments
 (0)