From aeb9ab98e61914c619ca98927c3d6c252ae66c75 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 18 Dec 2023 15:52:40 -0800 Subject: [PATCH 1/3] feat: support aten.copy dynamo converter --- .../dynamo/conversion/aten_ops_converters.py | 25 +++++++++++++++ tests/py/dynamo/conversion/test_copy_aten.py | 31 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_copy_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index c25e331cf5..8acf7130d7 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2547,3 +2547,28 @@ def aten_ops_trunc( name, args[0], ) + + +@dynamo_tensorrt_converter(torch.ops.aten.copy.default) +@enforce_tensor_types( + { + 1: (TRTTensor,), + } +) +def aten_ops_copy( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + src = args[1] + return impl.cast.to_copy( + ctx, + target, + SourceIR.ATEN, + name, + src, + src.dtype, + force_layer=True, + ) diff --git a/tests/py/dynamo/conversion/test_copy_aten.py b/tests/py/dynamo/conversion/test_copy_aten.py new file mode 100644 index 0000000000..1acb94daf6 --- /dev/null +++ b/tests/py/dynamo/conversion/test_copy_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestCopyConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3,), (3,), False), + ((1, 10), (1, 10), False), + ((2, 3, 4), (2, 3, 4), True), + ((2, 3, 4, 5), (2, 3, 4, 5), True), + ] + ) + def test_copy_float(self, input_shape, src_shape, non_blocking): + class Copy(nn.Module): + def forward(self, input, src): + return torch.ops.aten.copy.default(input, src, non_blocking) + + inputs = [torch.randn(input_shape), torch.randn(src_shape)] + self.run_test( + Copy(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 05e02c0fd2e46cd77cb78046eda3e2b4df7735b5 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 28 Dec 2023 08:14:06 -0800 Subject: [PATCH 2/3] fix bug --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8acf7130d7..a8a7710344 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2530,7 +2530,7 @@ def aten_ops_sort( @dynamo_tensorrt_converter(torch.ops.aten.trunc.default) @enforce_tensor_types( { - 0: (TRTTensor,), + 1: (TRTTensor,), } ) def aten_ops_trunc( From 2686ee1b1c95232104492639e97dba5e91058d1b Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 29 Dec 2023 00:54:17 -0800 Subject: [PATCH 3/3] fix bug --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index a8a7710344..8acf7130d7 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2530,7 +2530,7 @@ def aten_ops_sort( @dynamo_tensorrt_converter(torch.ops.aten.trunc.default) @enforce_tensor_types( { - 1: (TRTTensor,), + 0: (TRTTensor,), } ) def aten_ops_trunc(