From 06e83aa4f49e7ac7907a647e37c620be2b572000 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Tue, 26 Mar 2024 18:12:35 +0900 Subject: [PATCH 1/2] feat: support aten.expm1 converter --- .../dynamo/conversion/aten_ops_converters.py | 17 +++++ .../dynamo/conversion/impl/unary/ops.py | 29 ++++++++ tests/py/dynamo/conversion/test_expm1_aten.py | 69 +++++++++++++++++++ 3 files changed, 115 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_expm1_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0dd153d0aa..00daf34573 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1136,6 +1136,23 @@ def aten_ops_exp( ) +@dynamo_tensorrt_converter(torch.ops.aten.expm1.default) +def aten_ops_expm1( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.expm1( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.log.default) def aten_ops_log( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 554640ea5a..52e2bb6b3d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -44,6 +44,35 @@ def exp( ) +def expm1( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Computes e^x - 1 for each element of the input tensor. + + Args: + ctx (ConversionContext): TensorRT ConversionContext object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + TRTTensor: A TensorRT tensor represent the result of expm1 operator. + """ + # Compute e^x for each element of the input tensor + exp_result = exp(ctx, target, source_ir, f"{name}_exp", input_val) + + # # Subtract 1 from the result of the exponential operation + # expm1_result = sub(ctx, target, source_ir, f"{name}_sub", exp_result, 1) + + return impl.elementwise.sub(ctx, target, source_ir, f"{name}_div", exp_result, 1) + + def log( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_expm1_aten.py b/tests/py/dynamo/conversion/test_expm1_aten.py new file mode 100644 index 0000000000..e695a27475 --- /dev/null +++ b/tests/py/dynamo/conversion/test_expm1_aten.py @@ -0,0 +1,69 @@ +from math import exp + +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestExpConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_expm1_float(self, input_shape, dtype): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [torch.randn(input_shape, dtype=dtype)] + self.run_test( + expm1(), + inputs, + ) + + @parameterized.expand( + [ + (torch.full((1, 20), exp(1), dtype=torch.float),), + (torch.full((2, 3, 4), exp(2), dtype=torch.float),), + (torch.full((2, 3, 4, 5), exp(3), dtype=torch.float),), + ] + ) + def test_expm1_exp_const_float(self, data): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [data] + self.run_test( + expm1(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.int, 0, 5), + ((1, 20), torch.int32, -10, 10), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_exp_int(self, input_shape, dtype, low, high): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + expm1(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From c7bf9be5e6f26e0d52685b7c2b74a5379cb3215c Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 1 Apr 2024 15:58:33 +0900 Subject: [PATCH 2/2] chore: minor fix --- py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 52e2bb6b3d..4b81de2bb5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -67,10 +67,7 @@ def expm1( # Compute e^x for each element of the input tensor exp_result = exp(ctx, target, source_ir, f"{name}_exp", input_val) - # # Subtract 1 from the result of the exponential operation - # expm1_result = sub(ctx, target, source_ir, f"{name}_sub", exp_result, 1) - - return impl.elementwise.sub(ctx, target, source_ir, f"{name}_div", exp_result, 1) + return impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", exp_result, 1) def log(