
Commit e77d993

apbosenarendasan authored and committed

converter reorg for leaky relu

correcting nn_ops for leaky_relu and correcting linting error

1 parent 3fc3c6d · commit e77d993

File tree

5 files changed: +126 -11 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+3 -11

@@ -1026,17 +1026,9 @@ def acc_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    negative_slope = kwargs["negative_slope"]
-    operation_type = trt.ActivationType.LEAKY_RELU
-    return activation.convert_activation(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        operation_type,
-        input_val,
-        alpha=negative_slope,
+
+    return activation.leaky_relu(
+        network, target, SourceIR.ACC, name, kwargs["input"], kwargs["negative_slope"]
     )
 
 
py/torch_tensorrt/fx/converters/aten_ops_converters.py

+27 -0

@@ -215,6 +215,33 @@ def aten_ops_hardtanh(
     )
 
 
+@tensorrt_converter(torch.ops.aten.fmod.Tensor)
+def aten_ops_fmod(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     network: TRTNetwork,
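
Brief aside (not part of the commit): the args[0]/args[1] mapping in aten_ops_leaky_relu follows the positional layout of the aten overload, which can be checked against the eager op. A minimal sketch, assuming a standard PyTorch install:

import torch

x = torch.randn(2, 3)
# aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
y = torch.ops.aten.leaky_relu.default(x, 0.05)
assert torch.allclose(y, torch.nn.functional.leaky_relu(x, negative_slope=0.05))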

py/torch_tensorrt/fx/converters/impl/activation.py

+27 -0

@@ -148,3 +148,30 @@ def sigmoid_fn(x):
         input_val,
         dyn_range_fn=sigmoid_dyn_range_fn,
     )
+
+
+def leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.LEAKY_RELU
+
+    def leaky_relu_dyn_range_fn(dyn_range):
+        return (max(0, dyn_range[0]) + alpha * min(0, dyn_range[0])), (
+            max(0, dyn_range[1]) + alpha * min(0, dyn_range[1])
+        )
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=leaky_relu_dyn_range_fn,
+    )
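
For reference, leaky_relu_dyn_range_fn just pushes both ends of an INT8 dynamic range through the activation, since LeakyReLU maps x to x for x >= 0 and to alpha * x for x < 0. A standalone sketch of the same arithmetic (illustrative only, with an assumed alpha of 0.05):

def leaky_relu_dyn_range(dyn_range, alpha=0.05):
    # Apply max(0, x) + alpha * min(0, x) to each bound, mirroring the helper above.
    lo, hi = dyn_range
    return (max(0, lo) + alpha * min(0, lo), max(0, hi) + alpha * min(0, hi))

print(leaky_relu_dyn_range((-4.0, 6.0)))  # -> (-0.2, 6.0)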

py/torch_tensorrt/fx/converters/nn_ops_converters.py

+16 -0

@@ -51,3 +51,19 @@ def hardtanh(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.leaky_relu)
+@tensorrt_converter(torch.nn.modules.activation.LeakyReLU)
+def leaky_relu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.leaky_relu(
+        network=network,
+        target="torch.nn.functional.leaky_relu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs["negative_slope"],
+    )
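
Illustration only (not part of the commit): the two call forms this NN-level converter is registered for. After the tracer normalizes positional args into kwargs, the converter above expects "input" and "negative_slope" keys, which both forms provide.

import torch
import torch.nn as nn

x = torch.randn(2, 3)
out_module = nn.LeakyReLU(negative_slope=0.05)(x)   # module form
out_functional = nn.functional.leaky_relu(x, 0.05)  # functional form
assert torch.equal(out_module, out_functional)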
New test file (path not captured in this view)

+53 -0

@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLeakyReLUConverter(DispatchTestCase):
+    def test_leaky_relu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()