diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 7f472261db..829ec59e95 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2278,6 +2278,29 @@ def aten_ops_reshape(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.pixel_shuffle.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_pixel_shuffle(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.pixel_shuffle(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @enforce_tensor_types({0: (TRTTensor,)})
 @dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
 def aten_ops_argmax(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 3a4c160d77..49ddb76e2c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,5 +1,6 @@
 from typing import Optional, Sequence, Union
 
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
@@ -19,3 +20,48 @@ def reshape(
     layer.reshape_dims = tuple(shape)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def pixel_shuffle(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    upscale_factor: int,
+) -> TRTTensor:
+    """Emulate ``aten.pixel_shuffle`` via reshape -> permute -> reshape.
+
+    Rearranges an ``(*, C*r**2, H, W)`` tensor to ``(*, C, H*r, W*r)``,
+    where ``r`` is ``upscale_factor``, matching ``torch.pixel_shuffle``.
+    """
+    shape = input.shape
+    in_channels, in_height, in_width = shape[-3:]
+    out_channels = in_channels // (upscale_factor**2)
+    out_height = in_height * upscale_factor
+    out_width = in_width * upscale_factor
+    new_shape = shape[:-3] + (
+        out_channels,
+        upscale_factor,
+        upscale_factor,
+        in_height,
+        in_width,
+    )
+    reshaped_tensor = reshape(
+        ctx, target, source_ir, f"{name}_reshape1", input, new_shape
+    )
+    rank = len(shape)
+    permute_shape = list(range(rank))
+    permute_shape.insert(-2, rank)
+    permute_shape.insert(-1, rank + 1)
+    permuted_tensor = impl.permutation.permute(
+        ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
+    )
+    return reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape2",
+        permuted_tensor,
+        shape[:-3] + (out_channels, out_height, out_width),
+    )
diff --git a/tests/py/dynamo/conversion/test_pixel_shuffle_aten.py b/tests/py/dynamo/conversion/test_pixel_shuffle_aten.py
new file mode 100644
index 0000000000..a58212f894
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_pixel_shuffle_aten.py
@@ -0,0 +1,31 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestPixelShuffleConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((1, 1, 1), 1),
+            ((12, 3, 4), 2),
+            ((1, 9, 4, 4), 3),
+            ((2, 32, 2, 3), 4),
+            ((1, 10, 36, 2, 4), 6),
+        ]
+    )
+    def test_pixel_shuffle(self, shape, upscale_factor):
+        class PixelShuffle(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.pixel_shuffle.default(x, upscale_factor)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            PixelShuffle(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()