Skip to content

Commit 526b2e2

Browse files
apbosegs-olive
authored andcommitted
Rsqrt and linting error
1 parent 2de2db6 commit 526b2e2

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,11 @@ def aten_ops_rsqrt(
327327
kwargs: Dict[str, Argument],
328328
name: str,
329329
) -> Union[TRTTensor, Sequence[TRTTensor]]:
330-
330+
331331
return rsqrt(
332-
network,
333-
target,
334-
SourceIR.ATEN,
332+
network,
333+
target,
334+
SourceIR.ATEN,
335335
name,
336336
args[0],
337337
)

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,25 @@ def rsqrt(
117117
source_ir: Optional[SourceIR],
118118
name: str,
119119
input: TRTTensor,
120-
other: TRTTensor,
121120
) -> TRTTensor:
122-
121+
123122
sqrt_trt_output = convert_unary(
124123
network,
125124
target,
126125
source_ir,
127-
f"{name}"_sqrt,
126+
f"{name}_sqrt",
128127
trt.UnaryOperation.SQRT,
129128
input,
130129
)
131130

132131
output = convert_binary_elementwise(
133132
network,
133+
target,
134+
source_ir,
135+
f"{name}_output",
136+
trt.ElementWiseOperation.DIV,
134137
1,
135138
sqrt_trt_output,
136-
trt.ElementWiseOperation.DIV,
137-
target,
138-
f"{name}_outpur",
139139
)
140-
141-
return output
140+
141+
return output

py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
66

77

8-
class TestRSubConverter(DispatchTestCase):
8+
class TestRSqrtConverter(DispatchTestCase):
99
@parameterized.expand(
1010
[
1111
("2d_dim_alpha", (2, 1), 2),
@@ -26,4 +26,4 @@ def forward(self, input):
2626

2727

2828
if __name__ == "__main__":
29-
run_tests()
29+
run_tests()

0 commit comments

Comments
 (0)