From f69668d87f9606029bf82685c447ce28c5fa11cd Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 22 May 2024 12:49:44 +0900 Subject: [PATCH 1/4] feat: support aten.as_strided converter --- .../dynamo/conversion/aten_ops_converters.py | 20 ++++++ .../dynamo/conversion/impl/slice/ops.py | 58 +++++++++++++++++ .../dynamo/conversion/test_as_strided_aten.py | 65 +++++++++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_as_strided_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 15a993668b..66a25b4d2c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2206,6 +2206,26 @@ def aten_ops_cdist_forward( ) +@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default) +def aten_ops_as_strided( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.as_strided( + ctx, + target, + source_ir=SourceIR.ATEN, + name=name, + input=args[0], + size=args[1], + stride=args[2], + storage_offset=args_bounds_check(args, 3, None), + ) + + def avg_pool_param_validator(pool_node: Node) -> bool: ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 61d71fe9a0..44936b733c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -3,11 +3,13 @@ import numpy as np import tensorrt as trt +import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( + flatten_dims, get_positive_dim, get_trt_tensor, ) @@ -259,3 +261,59 @@ def flip( ) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) + + +def as_strided( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + size: Sequence[int], + stride: Sequence[int], + storage_offset: int, +) -> TRTTensor: + assert len(size) == len(stride), "size and stride shapes must be the same" + + flatten_shape = flatten_dims(input, 0, -1) + flatten_output = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape", input, flatten_shape + ) + + indices = [] + + # Recursive function to compute indices for as_strided operation + def nested( + rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int + ) -> None: + if ( + dim == rank + ): # If the current dimension equals the rank, append the computed index + indices.append(current) + return + for i in range(size[dim]): # Recursively compute indices across dimensions + nested( + rank, size, stride, current + stride[dim] * i, dim + 1 + ) # Calculate the index for the current dimension and recursively explore further dimensions + + nested(len(size), size, stride, storage_offset, 0) + + indices = torch.tensor(indices, dtype=torch.int) + + indices_tensor = get_trt_tensor(ctx, (indices), f"{name}_indices") + + # Use gather to reorder elements based on computed indices + gather_layer = 
ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
+    gather_output = gather_layer.get_output(0)
+
+    # Reshape the gathered tensor to the desired size
+    reshape_output = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape",
+        gather_output,
+        tuple(size),
+    )
+
+    return reshape_output
diff --git a/tests/py/dynamo/conversion/test_as_strided_aten.py b/tests/py/dynamo/conversion/test_as_strided_aten.py
new file mode 100644
index 0000000000..9437a04566
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_as_strided_aten.py
@@ -0,0 +1,65 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestAsStridedConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                (5, 5),
+                (2, 3),
+                (1, 2),
+                0,
+            ),
+            (
+                (5, 5),
+                (2, 3),
+                (2, 2),
+                1,
+            ),
+            (
+                (20, 20),
+                (2, 3, 2),
+                (2, 2, 2),
+                0,
+            ),
+            (
+                (8, 8, 8),
+                (2, 2, 3),
+                (1, 2, 2),
+                1,
+            ),
+            (
+                (200, 200, 200),
+                (9, 9, 3, 2),
+                (2, 2, 2, 3),
+                1,
+            ),
+        ]
+    )
+    def test_as_strided(
+        self,
+        input_shape,
+        output_size,
+        stride,
+        storage_offset=0,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.as_strided.default(
+                    x, output_size, stride, storage_offset
+                )
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            TestModule(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

From b4fce3e9c824d50ede7e2d71ec9a85c36fb4d560 Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Wed, 22 May 2024 12:50:39 +0900
Subject: [PATCH 2/4] feat: add validator for zero-shape outputs and test cases

---
 .../dynamo/conversion/aten_ops_converters.py   | 15 +++++++++++-
 .../dynamo/conversion/impl/slice/ops.py        | 15 ++++++------
 .../dynamo/conversion/test_as_strided_aten.py  | 24 +++++++++++++++++++
 3 files changed, 46 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 66a25b4d2c..e6898a3e4f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2206,7 +2206,20 @@ def aten_ops_cdist_forward(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
+def zero_output_validator(node: Node) -> bool:
+    if 0 in node.args[1]:
+        _LOGGER.debug(
+            f"We do not support output tensors of size {node.args[1]} with zero-sized dimensions for this operation."
+        )
+        return False
+    else:
+        return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.as_strided.default,
+    capability_validator=zero_output_validator,
+)
 def aten_ops_as_strided(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 44936b733c..139ecc1149 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -3,7 +3,6 @@
 
 import numpy as np
 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -271,13 +270,15 @@ def as_strided(
     input: TRTTensor,
     size: Sequence[int],
     stride: Sequence[int],
-    storage_offset: int,
+    storage_offset: Optional[int],
 ) -> TRTTensor:
-    assert len(size) == len(stride), "size and stride shapes must be the same"
+    # Ensure storage_offset is an integer before passing to nested
+    if storage_offset is None:
+        storage_offset = 0
 
     flatten_shape = flatten_dims(input, 0, -1)
     flatten_output = impl.shuffle.reshape(
-        ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
+        ctx, target, source_ir, f"{name}_reshape_flatten_output", input, flatten_shape
     )
 
     indices = []
@@ -298,9 +299,9 @@ def nested(
 
     nested(len(size), size, stride, storage_offset, 0)
 
-    indices = torch.tensor(indices, dtype=torch.int)
+    indices = np.array(indices, dtype=np.int32)
 
-    indices_tensor = get_trt_tensor(ctx, (indices), f"{name}_indices")
+    indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")
 
     # Use gather to reorder elements based on computed indices
     gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
@@ -311,7 +312,7 @@ def nested(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape",
+        f"{name}_reshape_gather_output",
         gather_output,
         tuple(size),
     )
diff --git a/tests/py/dynamo/conversion/test_as_strided_aten.py b/tests/py/dynamo/conversion/test_as_strided_aten.py
index 9437a04566..ba723bf4f9 100644
--- a/tests/py/dynamo/conversion/test_as_strided_aten.py
+++ b/tests/py/dynamo/conversion/test_as_strided_aten.py
@@ -39,6 +39,30 @@ class TestAsStridedConverter(DispatchTestCase):
                 (2, 2, 2, 3),
                 1,
             ),
+            (
+                (10, 25, 12),
+                (3, 7, 3),
+                (2, 1, 3),
+                1,
+            ),
+            (
+                (10, 25, 12),
+                (3, 7, 3),
+                (2, 0, 3),
+                1,
+            ),
+            (
+                (10, 25, 12, 100),
+                (6, 5, 7, 10),
+                (0, 0, 0, 0),
+                0,
+            ),
+            (
+                (10, 25, 12, 100),
+                (6, 5, 7, 10),
+                (0, 0, 0, 0),
+                1,
+            ),
         ]
     )
     def test_as_strided(

From d3441689719cd4e0c5e1abbe4a164a7cd287c27c Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Wed, 22 May 2024 12:53:58 +0900
Subject: [PATCH 3/4] chore: move functions to organize code better

---
 .../dynamo/conversion/aten_ops_converters.py | 67 ++++++++++----------
 1 file changed, 33 insertions(+), 34 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index e6898a3e4f..297eb70853 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -800,6 +800,39 @@ def aten_ops_tile(
     )
 
 
+def zero_output_validator(node: Node) -> bool:
+    if 0 in node.args[1]:
+        _LOGGER.debug(
+            f"We do not support output tensors of size {node.args[1]} with zero-sized dimensions for this operation."
+        )
+        return False
+    else:
+        return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.as_strided.default,
+    capability_validator=zero_output_validator,
+)
+def aten_ops_as_strided(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.as_strided(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        size=args[1],
+        stride=args[2],
+        storage_offset=args_bounds_check(args, 3, None),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.permute.default)
 @enforce_tensor_types(
     {
@@ -2185,7 +2218,6 @@ def aten_ops_linear(
         bias=args_bounds_check(args, 2, None),
     )
 
-
 @dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
 def aten_ops_cdist_forward(
     ctx: ConversionContext,
@@ -2206,39 +2238,6 @@ def aten_ops_cdist_forward(
     )
 
 
-def zero_output_validator(node: Node) -> bool:
-    if 0 in node.args[1]:
-        _LOGGER.debug(
-            f"We do not support output tensors of size {node.args[1]} with zero-sized dimensions for this operation."
-        )
-        return False
-    else:
-        return True
-
-
-@dynamo_tensorrt_converter(
-    torch.ops.aten.as_strided.default,
-    capability_validator=zero_output_validator,
-)
-def aten_ops_as_strided(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.slice.as_strided(
-        ctx,
-        target,
-        source_ir=SourceIR.ATEN,
-        name=name,
-        input=args[0],
-        size=args[1],
-        stride=args[2],
-        storage_offset=args_bounds_check(args, 3, None),
-    )
-
-
 def avg_pool_param_validator(pool_node: Node) -> bool:
     ceil_mode = args_bounds_check(pool_node.args, 4, False)
     divisor_override = args_bounds_check(pool_node.args, 6)

From 38cb51eaf98de792ae23b72507ca9ef366d30c06 Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Wed, 22 May 2024 13:34:36 +0900
Subject: [PATCH 4/4] chore: resolve linting error

---
 py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 297eb70853..11a213551d 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2218,6 +2218,7 @@ def aten_ops_linear(
         bias=args_bounds_check(args, 2, None),
     )
 
+
 @dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
 def aten_ops_cdist_forward(
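
Reviewer note: the conversion strategy in impl/slice/ops.py can be sanity-checked outside of TensorRT. For each element of the requested output, the converter's `nested` recursion computes the flat index that element occupies in the contiguous input storage, then emulates as_strided as a gather over the flattened input followed by a reshape. The sketch below is not part of the patch; it is a minimal plain-PyTorch check of that logic, and the helper name `as_strided_indices` is ours:

import torch


def as_strided_indices(size, stride, storage_offset=0):
    """Enumerate flat storage indices, mirroring the converter's `nested` recursion."""
    indices = []

    def nested(current, dim):
        # Once every output dimension is fixed, `current` is the flat source index
        if dim == len(size):
            indices.append(current)
            return
        for i in range(size[dim]):
            nested(current + stride[dim] * i, dim + 1)

    nested(storage_offset, 0)
    return torch.tensor(indices, dtype=torch.int64)


x = torch.randn(5, 5)
size, stride, offset = (2, 3), (2, 2), 1  # one of the parameterized test cases above
idx = as_strided_indices(size, stride, offset)
emulated = x.flatten()[idx].reshape(size)  # gather + reshape, as the converter does
assert torch.equal(emulated, torch.as_strided(x, size, stride, offset))

This also makes the role of zero_output_validator concrete: an output size containing 0 produces an empty index list, a case the debug message says the gather-based conversion does not support, so such nodes are left to run in PyTorch rather than being converted.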