diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 478cf98dea..92aa212cad 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2708,6 +2708,23 @@ def aten_ops_scalar_tensor(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
+def log10(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.log10(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.roll.default)
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index fc6c737e79..784a21009d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -61,6 +61,22 @@ def log(
     )
 
 
+def log10(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    log_layer_output = log(ctx, target, source_ir, f"{name}_log", input_val)
+
+    ln10 = 2.302585092994046
+
+    return impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_div", log_layer_output, ln10
+    )
+
+
 def sqrt(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_log10.py b/tests/py/dynamo/conversion/test_log10.py
new file mode 100644
index 0000000000..9094f6b278
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_log10.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestLogConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.float),
+            ((1, 20), torch.float),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_log10_float(self, input_shape, dtype):
+        class log10(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.log10.default(input)
+
+        inputs = [torch.randn(input_shape, dtype=dtype)]
+        self.run_test(
+            log10(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((10,), torch.int, 0, 5),
+            ((1, 20), torch.int32, -10, 10),
+            ((2, 3, 4), torch.int, -5, 5),
+        ]
+    )
+    def test_log10_int(self, input_shape, dtype, low, high):
+        class log10(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.log10.default(input)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            log10(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()