diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 361c5f4840..f270ce3ea8 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -1026,17 +1026,9 @@ def acc_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    negative_slope = kwargs["negative_slope"]
-    operation_type = trt.ActivationType.LEAKY_RELU
-    return activation.convert_activation(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        operation_type,
-        input_val,
-        alpha=negative_slope,
+
+    return activation.leaky_relu(
+        network, target, SourceIR.ACC, name, kwargs["input"], kwargs["negative_slope"]
     )
 
 
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index e7127d16d4..4f93d98a26 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -215,6 +215,33 @@ def aten_ops_hardtanh(
     )
 
 
+@tensorrt_converter(torch.ops.aten.fmod.Tensor)
+def aten_ops_fmod(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     network: TRTNetwork,
diff --git a/py/torch_tensorrt/fx/converters/impl/activation.py b/py/torch_tensorrt/fx/converters/impl/activation.py
index 498850ec3e..793d0f90c9 100644
--- a/py/torch_tensorrt/fx/converters/impl/activation.py
+++ b/py/torch_tensorrt/fx/converters/impl/activation.py
@@ -175,3 +175,30 @@ def tanh_fn(x):
         input_val,
         dyn_range_fn=tanh_dyn_range_fn,
     )
+
+
+def leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.LEAKY_RELU
+
+    def leaky_relu_dyn_range_fn(dyn_range):
+        return (max(0, dyn_range[0]) + alpha * min(0, dyn_range[0])), (
+            max(0, dyn_range[1]) + alpha * min(0, dyn_range[1])
+        )
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=leaky_relu_dyn_range_fn,
+    )
diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py
index 4351ff5651..37b5de4115 100644
--- a/py/torch_tensorrt/fx/converters/nn_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py
@@ -66,3 +66,19 @@ def tanh(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.leaky_relu)
+@tensorrt_converter(torch.nn.modules.activation.LeakyReLU)
+def leaky_relu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.leaky_relu(
+        network=network,
+        target="torch.nn.functional.leaky_relu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs["negative_slope"],
+    )
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py
new file mode 100644
index 0000000000..7cdce77092
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLeakyReLUConverter(DispatchTestCase):
+    def test_leaky_relu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+    def test_leaky_relu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.leaky_relu(x, negative_slope=0.05)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
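
Note on the shared helper added in impl/activation.py: leaky_relu_dyn_range_fn maps each bound of an INT8 dynamic range through max(0, x) + alpha * min(0, x), which is the scalar leaky ReLU itself. Below is a minimal standalone sketch of that mapping (not part of the patch; the alpha value and sample range are illustrative), checked against torch.nn.functional.leaky_relu:

# Sketch only: reproduces the bound mapping used by leaky_relu_dyn_range_fn
# and compares it with PyTorch's elementwise leaky_relu on the two bounds.
import torch
import torch.nn.functional as F

alpha = 0.05  # illustrative negative_slope


def leaky_relu_dyn_range_fn(dyn_range):
    # max(0, x) + alpha * min(0, x), applied to (low, high)
    return (
        max(0, dyn_range[0]) + alpha * min(0, dyn_range[0]),
        max(0, dyn_range[1]) + alpha * min(0, dyn_range[1]),
    )


low, high = leaky_relu_dyn_range_fn((-4.0, 4.0))
assert torch.isclose(torch.tensor(low), F.leaky_relu(torch.tensor(-4.0), alpha))
assert torch.isclose(torch.tensor(high), F.leaky_relu(torch.tensor(4.0), alpha))
print(low, high)  # -0.2 4.0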