diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b05713c360..459c69bfa5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -330,6 +330,34 @@ def aten_ops_fmod( return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (TRTTensor,), + } +) +def aten_ops_grid( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.grid.grid( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + grid=args[1], + interpolation_mode=args[2], + padding_mode=args[3], + align_corners=args[4], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.relu.default) def aten_ops_relu( ctx: ConversionContext, @@ -754,12 +782,12 @@ def aten_ops_cumsum( ) -@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.tile.default) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_tile( ctx: ConversionContext, target: Target, @@ -777,7 +805,7 @@ def aten_ops_tile( ) -@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.permute.default) @enforce_tensor_types( { 0: (TRTTensor,), @@ -1990,14 +2018,14 @@ def aten_ops_argmax( ) -@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) @enforce_tensor_types( { 0: (TRTTensor,), 1: (np.ndarray, torch.Tensor, TRTTensor), 2: (np.ndarray, torch.Tensor, TRTTensor), } -) # type: ignore[misc] +) def aten_ops_addmm( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index ab0c29e7d5..b448b40bc3 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -12,6 +12,7 @@ deconv, elementwise, embedding, + grid, linear, matmul, normalization, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py new file mode 100644 index 0000000000..672fc97351 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -0,0 +1,47 @@ +from typing import Optional, Sequence + +import tensorrt as trt +import torch +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +# nearest, linear, cubic +GridSamplerInterpolationMode = { + 0: trt.InterpolationMode.NEAREST, + 1: trt.InterpolationMode.LINEAR, + 2: trt.InterpolationMode.CUBIC, +} + +# zeros, border, reflection +GridSamplerSampling = { + 0: trt.SampleMode.FILL, + 1: trt.SampleMode.CLAMP, + 2: trt.SampleMode.REFLECT, +} + + +def grid( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + grid: TRTTensor, + interpolation_mode: int, + padding_mode: int, + align_corners: bool, +) -> TRTTensor: + grid_layer = ctx.net.add_grid_sample(input, grid) + assert interpolation_mode in GridSamplerInterpolationMode + grid_layer.interpolation_mode = GridSamplerInterpolationMode.get( + interpolation_mode, None + ) + assert padding_mode in GridSamplerSampling + grid_layer.sample_mode = GridSamplerSampling.get(padding_mode, None) + grid_layer.align_corners = align_corners + set_layer_name(grid_layer, target, name + "_grid_layer", source_ir) + return grid_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py new file mode 100644 index 0000000000..32480110f3 --- /dev/null +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -0,0 +1,149 @@ +import pytest +import torch +import torch.nn as nn +from .harness import DispatchTestCase +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +grid_sampler_ops = [ + ( + "input_grid_interpolation_nearest_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), +] + + +class TestGridConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1], + grid_sampler_op[2], + grid_sampler_op[3], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid(self, _, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + +if __name__ == "__main__": + run_tests()