diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index c25e331cf5..8acf7130d7 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2547,3 +2547,28 @@ def aten_ops_trunc(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.copy.default)
+@enforce_tensor_types(
+    {
+        1: (TRTTensor,),
+    }
+)
+def aten_ops_copy(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # aten.copy(self, src): the converter output is src; force_layer=True
+    # inserts an identity/cast layer so a TRT layer exists even when the
+    # dtype is unchanged (src is cast to its own dtype).
+    src = args[1]
+    return impl.cast.to_copy(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        src,
+        src.dtype,
+        force_layer=True,
+    )
diff --git a/tests/py/dynamo/conversion/test_copy_aten.py b/tests/py/dynamo/conversion/test_copy_aten.py
new file mode 100644
index 0000000000..1acb94daf6
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_copy_aten.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestCopyConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3,), (3,), False),
+            ((1, 10), (1, 10), False),
+            ((2, 3, 4), (2, 3, 4), True),
+            ((2, 3, 4, 5), (2, 3, 4, 5), True),
+        ]
+    )
+    def test_copy_float(self, input_shape, src_shape, non_blocking):
+        class Copy(nn.Module):
+            def forward(self, input, src):
+                return torch.ops.aten.copy.default(input, src, non_blocking)
+
+        inputs = [torch.randn(input_shape), torch.randn(src_shape)]
+        self.run_test(
+            Copy(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()