From e39b69e54a39b5a7090d8a7dcc92e4869bf16065 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 5 Sep 2023 15:15:09 -0700 Subject: [PATCH 01/10] Expose IGridSampleLayer --- .../dynamo/conversion/aten_ops_converters.py | 16 ++++++++ .../dynamo/conversion/converter_utils.py | 26 +++++++++++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/grid.py | 26 +++++++++++++ tests/py/dynamo/conversion/test_grid_aten.py | 38 +++++++++++++++++++ 5 files changed, 107 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/grid.py create mode 100644 tests/py/dynamo/conversion/test_grid_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b05713c360..e0f8edc4de 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -330,6 +330,22 @@ def aten_ops_fmod( return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out) +def aten_ops_grid( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4]) + + @dynamo_tensorrt_converter(torch.ops.aten.relu.default) def aten_ops_relu( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index b65f95f0e5..367683f5a5 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -24,6 +24,32 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +#nearesr, linear, cubc +class GridSamplerInterpolation: + def __init__(self): + self.interpolator_mode = None + def __call__(self, interpolator_int): + if(interpolator_int == 0) : + self.interpolator_mode = trt.InterpolationMode.NEAREST + elif(interpolator_int == 1) : + self.interpolator_mode = trt.InterpolationMode.LINEAR + elif(interpolator_int == 2) : + self.interpolator_mode = trt.InterpolationMode.CUBIC + return self.interpolator_mode + + +#zeros, border, reflection +class GridSamplerPadding: + def __init__(self): + self.padding_mode = None + def __call__(self, padding_int): + if(padding_int == 0) : + self.padding_mode = trt.SampleMode.kFILL + elif(padding_int == 1) : + self.padding_mode = trt.SampleMode.kCLAMP + elif(padding_int == 2) : + self.padding_mode = trt.SampleMode.kREFLECT + return self.padding_mode def get_node_name(node: torch.fx.Node) -> str: # nn_module_stack preserves the call stack of pytorch nn.modules diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index ab0c29e7d5..b448b40bc3 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -12,6 +12,7 @@ deconv, elementwise, embedding, + grid, linear, matmul, normalization, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py new file mode 100644 index 0000000000..3a28ae4e71 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -0,0 +1,26 @@ +from typing import Optional + +import torch +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +def grid( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + grid: TRTTensor, + interpolation_mode: int, + padding_mode: int, + align_corners: bool, +) -> TRTTensor: + grid_layer = network.add_grid_sample(input, grid) + grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode) + grid_layer.padding_mode = GridSamplerPadding(padding_mode) + grid_layer.align_corners = align_corners + set_layer_name(grid_layer, target, name + "_grid_layer", source_ir) + return grid_layer.get_output(0) \ No newline at end of file diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py new file mode 100644 index 0000000000..048450d5b4 --- /dev/null +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -0,0 +1,38 @@ +import pytest +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input +from parameterized import parameterized +from .harness import DispatchTestCase + +class TestGridConverter(DispatchTestCase): + @parameterized.expand( + [ + ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0), + ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1), + ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2), + ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0), + ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1), + ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2), + ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0), + ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1), + ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2), + ] + ) + def test_grid(self,_, input_shape, dim_shape, interpolation, sample): + class TestModule(nn.Module): + def forward(self, x): + input = torch.randn(10).reshape(input_shape) + grid = torch.randint(-1, 1, dim_shape) + return nn.functional.grid(input, grid, interpolation, sample) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}) + + + + + + + \ No newline at end of file From 13319d802d2244e000892df14800059d88ba1985 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 12 Oct 2023 17:21:30 -0700 Subject: [PATCH 02/10] Grid test changes --- .../dynamo/conversion/aten_ops_converters.py | 23 ++++- .../dynamo/conversion/converter_utils.py | 38 ++++---- .../dynamo/conversion/impl/grid.py | 32 +++++-- tests/py/dynamo/conversion/test_grid_aten.py | 95 ++++++++++++++----- 4 files changed, 139 insertions(+), 49 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e0f8edc4de..5103b35921 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -330,12 +330,17 @@ def aten_ops_fmod( return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out) @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out) @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out) @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_grid( ctx: ConversionContext, target: Target, @@ -343,7 +348,19 @@ def aten_ops_grid( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4]) + return impl.grid.grid( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + grid=args[1], + interpolation_mode=args[2], + padding_mode=args[3], + align_corners=args_bounds_check(args, 4, True), + output_mask=args_bounds_check(args, 5, None), + + ) @dynamo_tensorrt_converter(torch.ops.aten.relu.default) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 367683f5a5..9fc981c10a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -24,32 +24,36 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -#nearesr, linear, cubc + +# nearest, linear, cubic class GridSamplerInterpolation: def __init__(self): self.interpolator_mode = None - def __call__(self, interpolator_int): - if(interpolator_int == 0) : + + def __call__(self, interpolator_int): + if interpolator_int == 0: self.interpolator_mode = trt.InterpolationMode.NEAREST - elif(interpolator_int == 1) : + elif interpolator_int == 1: self.interpolator_mode = trt.InterpolationMode.LINEAR - elif(interpolator_int == 2) : + elif interpolator_int == 2: self.interpolator_mode = trt.InterpolationMode.CUBIC return self.interpolator_mode - -#zeros, border, reflection -class GridSamplerPadding: + +# zeros, border, reflection +class GridSamplerSampling: def __init__(self): - self.padding_mode = None - def __call__(self, padding_int): - if(padding_int == 0) : - self.padding_mode = trt.SampleMode.kFILL - elif(padding_int == 1) : - self.padding_mode = trt.SampleMode.kCLAMP - elif(padding_int == 2) : - self.padding_mode = trt.SampleMode.kREFLECT - return self.padding_mode + self.sample_mode = None + + def __call__(self, sample_int): + if sample_int == 0: + self.sample_mode = trt.SampleMode.FILL + elif sample_int == 1: + self.sample_mode = trt.SampleMode.CLAMP + elif sample_int == 2: + self.sample_mode = trt.SampleMode.REFLECT + return self.sample_mode + def get_node_name(node: torch.fx.Node) -> str: # nn_module_stack preserves the call stack of pytorch nn.modules diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py index 3a28ae4e71..bee99293e4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -1,14 +1,21 @@ -from typing import Optional +from typing import Optional, Sequence +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + GridSamplerInterpolation, + GridSamplerSampling, + cast_trt_tensor, +) from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + def grid( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, @@ -17,10 +24,21 @@ def grid( interpolation_mode: int, padding_mode: int, align_corners: bool, + output_mask: Optional[Sequence[bool]] = None, ) -> TRTTensor: - grid_layer = network.add_grid_sample(input, grid) - grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode) - grid_layer.padding_mode = GridSamplerPadding(padding_mode) + grid_layer = ctx.net.add_grid_sample(input, grid) + interpolation_mode_trt = GridSamplerInterpolation() + grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode) + sample_mode_trt = GridSamplerSampling() + grid_layer.sample_mode = sample_mode_trt(padding_mode) grid_layer.align_corners = align_corners set_layer_name(grid_layer, target, name + "_grid_layer", source_ir) - return grid_layer.get_output(0) \ No newline at end of file + if output_mask is None: + return grid_layer.get_output(0) + else: + if output_mask[0] and output_mask[1]: + return (grid_layer.get_output(0), None) + elif output_mask[0]: + return grid_layer.get_output(0) + else: + return None diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py index 048450d5b4..5ac615c78c 100644 --- a/tests/py/dynamo/conversion/test_grid_aten.py +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -1,38 +1,89 @@ import pytest import torch import torch.nn as nn +from .harness import DispatchTestCase +from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from parameterized import parameterized -from .harness import DispatchTestCase + class TestGridConverter(DispatchTestCase): @parameterized.expand( [ - ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0), - ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1), - ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2), - ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0), - ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1), - ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2), - ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0), - ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1), - ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2), + ( + "input_grid_interpolation_nearest_sample_fill", + [1, 1, 5, 5], + [1, 5, 2, 2], + 0, + 0, + ), + ( + "input_grid_interpolation_nearest_sample_clamp", + [1, 1, 5, 5], + [1, 5, 2, 2], + 0, + 1, + ), + ( + "input_grid_interpolation_nearest_sample_reflect", + [1, 1, 5, 5], + [1, 5, 2, 2], + 0, + 2, + ), + ( + "input_grid_interpolation_linear_sample_fill", + [1, 1, 5, 5], + [1, 5, 2, 2], + 1, + 0, + ), + ( + "input_grid_interpolation_linear_sample_clamp", + [1, 1, 5, 5], + [1, 5, 2, 2], + 1, + 1, + ), + ( + "input_grid_interpolation_linear_sample_reflect", + [1, 1, 5, 5], + [1, 5, 2, 2], + 1, + 2, + ), + ( + "input_grid_interpolation_cubic_sample_fill", + [1, 1, 5, 5], + [1, 5, 2, 2], + 2, + 0, + ), + ( + "input_grid_interpolation_cubic_sample_clamp", + [1, 1, 5, 5], + [1, 5, 2, 2], + 2, + 1, + ), + ( + "input_grid_interpolation_cubic_sample_reflect", + [1, 1, 5, 5], + [1, 5, 2, 2], + 2, + 2, + ), ] ) - def test_grid(self,_, input_shape, dim_shape, interpolation, sample): + def test_grid(self, _, input_shape, dim_shape, interpolation, sample): class TestModule(nn.Module): def forward(self, x): - input = torch.randn(10).reshape(input_shape) - grid = torch.randint(-1, 1, dim_shape) - return nn.functional.grid(input, grid, interpolation, sample) - - inputs = [torch.randn(1, 10)] - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}) - - + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True) + inputs = [torch.randn(input_shape, dtype=torch.float32)] + self.run_test(TestModule(), inputs) - - \ No newline at end of file +if __name__ == "__main__": + run_tests() From 5c6905de5dc43c1ee4a447fe4806b4c13c685f85 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 19 Oct 2023 17:40:33 -0700 Subject: [PATCH 03/10] Addressing review comments --- .../dynamo/conversion/converter_utils.py | 30 -------------- .../dynamo/conversion/impl/grid.py | 41 ++++++++++++------- 2 files changed, 26 insertions(+), 45 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 9fc981c10a..b65f95f0e5 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -25,36 +25,6 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -# nearest, linear, cubic -class GridSamplerInterpolation: - def __init__(self): - self.interpolator_mode = None - - def __call__(self, interpolator_int): - if interpolator_int == 0: - self.interpolator_mode = trt.InterpolationMode.NEAREST - elif interpolator_int == 1: - self.interpolator_mode = trt.InterpolationMode.LINEAR - elif interpolator_int == 2: - self.interpolator_mode = trt.InterpolationMode.CUBIC - return self.interpolator_mode - - -# zeros, border, reflection -class GridSamplerSampling: - def __init__(self): - self.sample_mode = None - - def __call__(self, sample_int): - if sample_int == 0: - self.sample_mode = trt.SampleMode.FILL - elif sample_int == 1: - self.sample_mode = trt.SampleMode.CLAMP - elif sample_int == 2: - self.sample_mode = trt.SampleMode.REFLECT - return self.sample_mode - - def get_node_name(node: torch.fx.Node) -> str: # nn_module_stack preserves the call stack of pytorch nn.modules # The call stack contains a detailed name of the module diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py index bee99293e4..af4e1ecabc 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -5,14 +5,24 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import ( - GridSamplerInterpolation, - GridSamplerSampling, - cast_trt_tensor, -) +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +# nearest, linear, cubic +GridSamplerInterpolationMode = { + 0: trt.InterpolationMode.NEAREST, + 1: trt.InterpolationMode.LINEAR, + 2: trt.InterpolationMode.CUBIC, +} + +# zeros, border, reflection +GridSamplerSampling = { + 0: trt.SampleMode.FILL, + 1: trt.SampleMode.CLAMP, + 2: trt.SampleMode.REFLECT, +} + def grid( ctx: ConversionContext, @@ -27,18 +37,19 @@ def grid( output_mask: Optional[Sequence[bool]] = None, ) -> TRTTensor: grid_layer = ctx.net.add_grid_sample(input, grid) - interpolation_mode_trt = GridSamplerInterpolation() - grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode) - sample_mode_trt = GridSamplerSampling() - grid_layer.sample_mode = sample_mode_trt(padding_mode) + assert interpolation_mode in GridSamplerInterpolationMode + grid_layer.interpolation_mode = GridSamplerInterpolationMode.get( + interpolation_mode, None + ) + assert padding_mode in GridSamplerSampling + grid_layer.sample_mode = GridSamplerSampling.get(padding_mode, None) grid_layer.align_corners = align_corners set_layer_name(grid_layer, target, name + "_grid_layer", source_ir) if output_mask is None: return grid_layer.get_output(0) + elif output_mask[0] and output_mask[1]: + return (grid_layer.get_output(0), None) + elif output_mask[0]: + return grid_layer.get_output(0) else: - if output_mask[0] and output_mask[1]: - return (grid_layer.get_output(0), None) - elif output_mask[0]: - return grid_layer.get_output(0) - else: - return None + return None From b4619f05cbf775f24429fd8cff410d03bb7ca973 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 20 Oct 2023 15:37:09 -0700 Subject: [PATCH 04/10] Keeping the ignore[misc] and changing converter key --- .../dynamo/conversion/aten_ops_converters.py | 26 +++++++------------ .../dynamo/conversion/impl/grid.py | 10 +------ 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5103b35921..184acae1cf 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -330,11 +330,7 @@ def aten_ops_fmod( return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc] @enforce_tensor_types( { 0: (TRTTensor,), @@ -349,21 +345,19 @@ def aten_ops_grid( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.grid.grid( - ctx, - target, - SourceIR.ATEN, - name, - input=args[0], - grid=args[1], - interpolation_mode=args[2], - padding_mode=args[3], + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + grid=args[1], + interpolation_mode=args[2], + padding_mode=args[3], align_corners=args_bounds_check(args, 4, True), - output_mask=args_bounds_check(args, 5, None), - ) -@dynamo_tensorrt_converter(torch.ops.aten.relu.default) +@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc] def aten_ops_relu( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py index af4e1ecabc..672fc97351 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -34,7 +34,6 @@ def grid( interpolation_mode: int, padding_mode: int, align_corners: bool, - output_mask: Optional[Sequence[bool]] = None, ) -> TRTTensor: grid_layer = ctx.net.add_grid_sample(input, grid) assert interpolation_mode in GridSamplerInterpolationMode @@ -45,11 +44,4 @@ def grid( grid_layer.sample_mode = GridSamplerSampling.get(padding_mode, None) grid_layer.align_corners = align_corners set_layer_name(grid_layer, target, name + "_grid_layer", source_ir) - if output_mask is None: - return grid_layer.get_output(0) - elif output_mask[0] and output_mask[1]: - return (grid_layer.get_output(0), None) - elif output_mask[0]: - return grid_layer.get_output(0) - else: - return None + return grid_layer.get_output(0) From cb6168f7705e3ca71b5ef815bd51588ce03d98c4 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 20 Oct 2023 18:02:31 -0700 Subject: [PATCH 05/10] adding grid_sampler_2d and grid_sampler_3d --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 184acae1cf..f7bf15dc13 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -331,6 +331,8 @@ def aten_ops_fmod( @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc] @enforce_tensor_types( { 0: (TRTTensor,), From 5c49595617174977135e53f28385976fc6edc24e Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 23 Oct 2023 12:36:06 -0700 Subject: [PATCH 06/10] Removing optional arg for align_corner --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index f7bf15dc13..af1bcb6c44 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -355,7 +355,7 @@ def aten_ops_grid( grid=args[1], interpolation_mode=args[2], padding_mode=args[3], - align_corners=args_bounds_check(args, 4, True), + align_corners=args[4], ) From b298710c206d31995a3fc33f06ee954310101460 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 8 Nov 2023 14:01:06 -0800 Subject: [PATCH 07/10] Adding grid_sampler 2d cases (no 3d cases) --- .../dynamo/conversion/aten_ops_converters.py | 3 +- tests/py/dynamo/conversion/test_grid_aten.py | 203 ++++++++++++------ 2 files changed, 139 insertions(+), 67 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index af1bcb6c44..4b31eafe0b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -332,7 +332,8 @@ def aten_ops_fmod( @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc] +# commented this for now, see py/dynamo/conversion/tests/test_grid_aten. Should this be removed altogether? +# @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc] @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py index 5ac615c78c..9833165023 100644 --- a/tests/py/dynamo/conversion/test_grid_aten.py +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -1,88 +1,159 @@ import pytest import torch import torch.nn as nn -from .harness import DispatchTestCase +from harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +grid_sampler_ops = [ + ( + "input_grid_interpolation_nearest_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + # The 3d cases with 4d input gives the error that it requires 5d input for both input and grid + # The 5d input fails in the generation of the Grid Layer since the TensorRT layer requires 4d input + # ("input_grid_interpolation_nearest_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_nearest_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_nearest_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_linear_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_linear_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_linear_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_cubic_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_cubic_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), + # ("input_grid_interpolation_cubic_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), +] + class TestGridConverter(DispatchTestCase): @parameterized.expand( [ ( - "input_grid_interpolation_nearest_sample_fill", - [1, 1, 5, 5], - [1, 5, 2, 2], - 0, - 0, - ), - ( - "input_grid_interpolation_nearest_sample_clamp", - [1, 1, 5, 5], - [1, 5, 2, 2], - 0, - 1, - ), - ( - "input_grid_interpolation_nearest_sample_reflect", - [1, 1, 5, 5], - [1, 5, 2, 2], - 0, - 2, - ), - ( - "input_grid_interpolation_linear_sample_fill", - [1, 1, 5, 5], - [1, 5, 2, 2], - 1, - 0, - ), - ( - "input_grid_interpolation_linear_sample_clamp", - [1, 1, 5, 5], - [1, 5, 2, 2], - 1, - 1, - ), - ( - "input_grid_interpolation_linear_sample_reflect", - [1, 1, 5, 5], - [1, 5, 2, 2], - 1, - 2, - ), - ( - "input_grid_interpolation_cubic_sample_fill", - [1, 1, 5, 5], - [1, 5, 2, 2], - 2, - 0, - ), - ( - "input_grid_interpolation_cubic_sample_clamp", - [1, 1, 5, 5], - [1, 5, 2, 2], - 2, - 1, - ), - ( - "input_grid_interpolation_cubic_sample_reflect", - [1, 1, 5, 5], - [1, 5, 2, 2], - 2, - 2, - ), + grid_sampler_op[0], + grid_sampler_op[1], + grid_sampler_op[2], + grid_sampler_op[3], + ) + for grid_sampler_op in grid_sampler_ops ] ) - def test_grid(self, _, input_shape, dim_shape, interpolation, sample): + def test_grid(self, _, op, input_shape, dim_shape): class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + def forward(self, x): grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) - return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True) + return self.grid_sampler_op(x, grid) inputs = [torch.randn(input_shape, dtype=torch.float32)] - self.run_test(TestModule(), inputs) + grid_model = TestModule(op) + self.run_test(grid_model, inputs) if __name__ == "__main__": From 783f7ef27261aa090615386cbe0d53c28f2158e2 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Nov 2023 15:39:51 -0800 Subject: [PATCH 08/10] Removing the misc and removing the grid_sampler.3d cases --- .../dynamo/conversion/aten_ops_converters.py | 10 ++++------ tests/py/dynamo/conversion/test_grid_aten.py | 11 ----------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4b31eafe0b..db99ec45cd 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -330,16 +330,14 @@ def aten_ops_fmod( return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) # type: ignore[misc] -# commented this for now, see py/dynamo/conversion/tests/test_grid_aten. Should this be removed altogether? -# @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) @enforce_tensor_types( { 0: (TRTTensor,), 1: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_grid( ctx: ConversionContext, target: Target, @@ -360,7 +358,7 @@ def aten_ops_grid( ) -@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.relu.default) def aten_ops_relu( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py index 9833165023..3475c41a71 100644 --- a/tests/py/dynamo/conversion/test_grid_aten.py +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -115,17 +115,6 @@ [1, 1, 5, 5], [1, 5, 2, 2], ), - # The 3d cases with 4d input gives the error that it requires 5d input for both input and grid - # The 5d input fails in the generation of the Grid Layer since the TensorRT layer requires 4d input - # ("input_grid_interpolation_nearest_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_nearest_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_nearest_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_linear_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_linear_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_linear_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_cubic_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_cubic_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), - # ("input_grid_interpolation_cubic_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]), ] From 39f5d312496b92dc049e33dc089a630fb83c5dd4 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Nov 2023 15:55:05 -0800 Subject: [PATCH 09/10] Removing ignore[misc] for addmm, tile, permute --- .../dynamo/conversion/aten_ops_converters.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index db99ec45cd..459c69bfa5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -782,12 +782,12 @@ def aten_ops_cumsum( ) -@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.tile.default) @enforce_tensor_types( { 0: (TRTTensor,), } -) # type: ignore[misc] +) def aten_ops_tile( ctx: ConversionContext, target: Target, @@ -805,7 +805,7 @@ def aten_ops_tile( ) -@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.permute.default) @enforce_tensor_types( { 0: (TRTTensor,), @@ -2018,14 +2018,14 @@ def aten_ops_argmax( ) -@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) @enforce_tensor_types( { 0: (TRTTensor,), 1: (np.ndarray, torch.Tensor, TRTTensor), 2: (np.ndarray, torch.Tensor, TRTTensor), } -) # type: ignore[misc] +) def aten_ops_addmm( ctx: ConversionContext, target: Target, From a6d25c41c6c721f51edec72ca1a2268b3f32ed10 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 13 Nov 2023 09:46:22 -0800 Subject: [PATCH 10/10] Changing .harness of grid test case --- tests/py/dynamo/conversion/test_grid_aten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py index 3475c41a71..32480110f3 100644 --- a/tests/py/dynamo/conversion/test_grid_aten.py +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -1,7 +1,7 @@ import pytest import torch import torch.nn as nn -from harness import DispatchTestCase +from .harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input