From 380b69830f1986737923f68486e81008b80f8218 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 13 Nov 2023 12:14:26 -0800 Subject: [PATCH 01/11] Gather in impl --- .../dynamo/conversion/aten_ops_converters.py | 25 +++++++++++++++++ .../dynamo/conversion/impl/select.py | 27 ++++++++++++------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 478cf98dea..2bbd3c7915 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -183,6 +183,31 @@ def aten_ops_native_group_norm( ) +@dynamo_tensorrt_converter(torch.ops.aten.gather) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] +def aten_ops_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.gather( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + dim=args[1], + index=args[2], + sparse_grad=args_bounds_check(args, 4, False), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) @dynamo_tensorrt_converter(torch.ops.aten.group_norm) @enforce_tensor_types( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index db586be65f..bfc671426f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -68,13 +68,26 @@ def select( indices_tensor = ctx.net.add_constant( index_value.shape, to_numpy(index_value) ).get_output(0) - layer = ctx.net.add_gather(input, indices_tensor, dim) - out = layer.get_output(0) + out = gather(input, indices_tensor, dim) if len(out.shape) != 1: layer = ctx.net.add_shuffle(out) return layer.get_output(0) +def gather( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], +) -> TRTTensor: + gather_layer = ctx.net.add_gather(input, index, dim) + set_layer_name(gather_layer, target, name + "_gather", source_ir) + return gather_layer.get_output(0) + + def index( ctx: ConversionContext, target: Target, @@ -127,9 +140,7 @@ def index( ) index = adv_indx_indices[0] _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") - gather_layer = ctx.net.add_gather(input, indices_tensor, index) - set_layer_name(gather_layer, target, name + "_index_gather", source_ir) - return gather_layer.get_output(0) + return gather(input, index, indices_tensor) else: input_shape = input.shape _LOGGER.debug(f"The input shape is {input.shape}") @@ -242,11 +253,7 @@ def index( dim_tensor_list[adv_indx_indices[i]], ) - gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0) - set_layer_name( - gather_layer_element, target, name + "_index_gather_element", source_ir - ) - gather_out = gather_layer_element.get_output(0) + gather_out = gather(flatten_tensor, cum_adv_index, 0) _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}") _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}") From af104031d5f85d56e21f81dae54943569a3632d5 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Nov 2023 10:46:02 -0800 Subject: [PATCH 02/11] changing gather signature --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 +- py/torch_tensorrt/dynamo/conversion/impl/select.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 2bbd3c7915..fc6d353f7e 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -183,7 +183,7 @@ def aten_ops_native_group_norm( ) -@dynamo_tensorrt_converter(torch.ops.aten.gather) +@dynamo_tensorrt_converter(torch.ops.aten.gather.default) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index bfc671426f..cc5e43e5b8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -68,7 +68,7 @@ def select( indices_tensor = ctx.net.add_constant( index_value.shape, to_numpy(index_value) ).get_output(0) - out = gather(input, indices_tensor, dim) + out = gather(ctx, target, source_ir, name, input, indices_tensor, dim) if len(out.shape) != 1: layer = ctx.net.add_shuffle(out) return layer.get_output(0) @@ -140,7 +140,7 @@ def index( ) index = adv_indx_indices[0] _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") - return gather(input, index, indices_tensor) + return gather(ctx, target, source_ir, name, input, index, indices_tensor) else: input_shape = input.shape _LOGGER.debug(f"The input shape is {input.shape}") @@ -253,7 +253,7 @@ def index( dim_tensor_list[adv_indx_indices[i]], ) - gather_out = gather(flatten_tensor, cum_adv_index, 0) + gather_out = gather(ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index) _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}") _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}") From 99e7de9327466f87ac9500a5dcab91e1c2e8cbd4 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 28 Nov 2023 14:07:11 -0800 Subject: [PATCH 03/11] Linting fix and adding test case --- .../dynamo/conversion/impl/select.py | 4 ++- .../py/dynamo/conversion/test_gather_aten.py | 30 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/py/dynamo/conversion/test_gather_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index cc5e43e5b8..6ee966bf3a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -253,7 +253,9 @@ def index( dim_tensor_list[adv_indx_indices[i]], ) - gather_out = gather(ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index) + gather_out = gather( + ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index + ) _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}") _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}") diff --git a/tests/py/dynamo/conversion/test_gather_aten.py b/tests/py/dynamo/conversion/test_gather_aten.py new file mode 100644 index 0000000000..d31fa08332 --- /dev/null +++ b/tests/py/dynamo/conversion/test_gather_aten.py @@ -0,0 +1,30 @@ +import operator + +import torch +import torch.nn as nn +from harness import DispatchTestCase +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + + +class TestIndexConverter(DispatchTestCase): + def test_index_zero_two_dim(self): + class TestModule(nn.Module): + def __init__(self): + self.index0 = torch.randint(0, 1, (1, 1)) + super().__init__() + + def forward(self, x): + index0 = torch.randint(0, 1, (1, 1)) + indices = [None, self.index0] + out = torch.ops.aten.gather.default(x, 0, indices) + return out + + input = [torch.randn(2, 2)] + self.run_test( + TestModule(), + input, + ) + +if __name__ == "__main__": + run_tests() \ No newline at end of file From e194877000b30261802a29a679093ec6128904c6 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 12 Dec 2023 14:19:41 -0800 Subject: [PATCH 04/11] linting error fix --- tests/py/dynamo/conversion/test_gather_aten.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_gather_aten.py b/tests/py/dynamo/conversion/test_gather_aten.py index d31fa08332..9064f9517c 100644 --- a/tests/py/dynamo/conversion/test_gather_aten.py +++ b/tests/py/dynamo/conversion/test_gather_aten.py @@ -26,5 +26,6 @@ def forward(self, x): input, ) + if __name__ == "__main__": - run_tests() \ No newline at end of file + run_tests() From b253bd16790e461567a71ea57d8bade798d3906d Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 19 Dec 2023 13:42:37 -0800 Subject: [PATCH 05/11] Correct sparse arg in aten::gather --- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../dynamo/conversion/impl/select.py | 1 + tests/py/dynamo/conversion/test_gather_aten.py | 18 ++++++++++-------- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index fc6d353f7e..2a3083dd57 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -204,7 +204,7 @@ def aten_ops_gather( input=args[0], dim=args[1], index=args[2], - sparse_grad=args_bounds_check(args, 4, False), + sparse_grad=args_bounds_check(args, 3, False), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6ee966bf3a..cab381fa2d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -82,6 +82,7 @@ def gather( input: TRTTensor, dim: int, index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], + sparse_grad: bool = False, ) -> TRTTensor: gather_layer = ctx.net.add_gather(input, index, dim) set_layer_name(gather_layer, target, name + "_gather", source_ir) diff --git a/tests/py/dynamo/conversion/test_gather_aten.py b/tests/py/dynamo/conversion/test_gather_aten.py index 9064f9517c..0317e23b94 100644 --- a/tests/py/dynamo/conversion/test_gather_aten.py +++ b/tests/py/dynamo/conversion/test_gather_aten.py @@ -2,25 +2,27 @@ import torch import torch.nn as nn -from harness import DispatchTestCase +from .harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -class TestIndexConverter(DispatchTestCase): - def test_index_zero_two_dim(self): +class TestGatherConverter(DispatchTestCase): + def test_gather_zero_two_dim(self): class TestModule(nn.Module): def __init__(self): - self.index0 = torch.randint(0, 1, (1, 1)) + # self.index0 = torch.randint(0, 1, (1, 1)) super().__init__() - def forward(self, x): - index0 = torch.randint(0, 1, (1, 1)) - indices = [None, self.index0] + def forward(self, x, indices): + # index0 = torch.randint(0, 1, (1, 1)) + # indices = [None, self.index0] out = torch.ops.aten.gather.default(x, 0, indices) return out - input = [torch.randn(2, 2)] + index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32) + indices = [None, index0] + input = [torch.randn(2, 2), index0] self.run_test( TestModule(), input, From 3d2ada6684089e508fc5554b83b87bc931503010 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 19 Dec 2023 14:01:49 -0800 Subject: [PATCH 06/11] converting indices to fp32 tensors --- py/torch_tensorrt/dynamo/conversion/impl/select.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index cab381fa2d..3ba6c79b18 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -84,7 +84,13 @@ def gather( index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], sparse_grad: bool = False, ) -> TRTTensor: - gather_layer = ctx.net.add_gather(input, index, dim) + indices_tensor = [] + + for i, ind in enumerate(index): + indices_tensor.append(get_trt_tensor( + ctx, ind, name + f"_parameter_to_fp32_tensor_{i}" + )) + gather_layer = ctx.net.add_gather(input, indices_tensor, dim) set_layer_name(gather_layer, target, name + "_gather", source_ir) return gather_layer.get_output(0) From 4ad27918924b7afd467f7b23441da58594abefa1 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 19 Dec 2023 14:05:32 -0800 Subject: [PATCH 07/11] Linting fix --- py/torch_tensorrt/dynamo/conversion/impl/select.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 3ba6c79b18..5e8f005f69 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -87,9 +87,9 @@ def gather( indices_tensor = [] for i, ind in enumerate(index): - indices_tensor.append(get_trt_tensor( - ctx, ind, name + f"_parameter_to_fp32_tensor_{i}" - )) + indices_tensor.append( + get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") + ) gather_layer = ctx.net.add_gather(input, indices_tensor, dim) set_layer_name(gather_layer, target, name + "_gather", source_ir) return gather_layer.get_output(0) From a0ff737cc08b388ce42e17d5463d8448cb84db58 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 5 Jan 2024 13:06:25 -0800 Subject: [PATCH 08/11] Correcting index test cas --- .../dynamo/conversion/impl/select.py | 17 +++++++++-------- tests/py/dynamo/conversion/test_gather_aten.py | 9 ++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 5e8f005f69..17d2189b7e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -9,6 +9,7 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( broadcastable, + cast_trt_tensor, get_positive_dim, get_trt_tensor, to_numpy, @@ -81,16 +82,16 @@ def gather( name: str, input: TRTTensor, dim: int, - index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], + index: Union[TRTTensor, np.ndarray, torch.Tensor], sparse_grad: bool = False, ) -> TRTTensor: - indices_tensor = [] - - for i, ind in enumerate(index): - indices_tensor.append( - get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") - ) - gather_layer = ctx.net.add_gather(input, indices_tensor, dim) + if not isinstance(index, TRTTensor): + index = get_trt_tensor(ctx, index, name + f"_parameter_to_fp32_tensor") + # This is for the case where torch.ops.aten.gather requires torch.int64 + # However TRTInterpreter complains that torch.int64 is not a supported type + # So the below cast does not help + # index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir) + gather_layer = ctx.net.add_gather(input, index, dim) set_layer_name(gather_layer, target, name + "_gather", source_ir) return gather_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_gather_aten.py b/tests/py/dynamo/conversion/test_gather_aten.py index 0317e23b94..9d939c13bb 100644 --- a/tests/py/dynamo/conversion/test_gather_aten.py +++ b/tests/py/dynamo/conversion/test_gather_aten.py @@ -2,26 +2,25 @@ import torch import torch.nn as nn -from .harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from .harness import DispatchTestCase + class TestGatherConverter(DispatchTestCase): def test_gather_zero_two_dim(self): class TestModule(nn.Module): def __init__(self): - # self.index0 = torch.randint(0, 1, (1, 1)) super().__init__() def forward(self, x, indices): # index0 = torch.randint(0, 1, (1, 1)) - # indices = [None, self.index0] out = torch.ops.aten.gather.default(x, 0, indices) return out - index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32) - indices = [None, index0] + # index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32) + index0 = torch.randint(0, 1, (1, 1)) input = [torch.randn(2, 2), index0] self.run_test( TestModule(), From b304233cf0b4786d181cde46ab411a76696feb6f Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 20 Feb 2024 12:45:56 -0800 Subject: [PATCH 09/11] removing gather test --- .../py/dynamo/conversion/test_gather_aten.py | 32 ------------------- 1 file changed, 32 deletions(-) delete mode 100644 tests/py/dynamo/conversion/test_gather_aten.py diff --git a/tests/py/dynamo/conversion/test_gather_aten.py b/tests/py/dynamo/conversion/test_gather_aten.py deleted file mode 100644 index 9d939c13bb..0000000000 --- a/tests/py/dynamo/conversion/test_gather_aten.py +++ /dev/null @@ -1,32 +0,0 @@ -import operator - -import torch -import torch.nn as nn -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input - -from .harness import DispatchTestCase - - -class TestGatherConverter(DispatchTestCase): - def test_gather_zero_two_dim(self): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, indices): - # index0 = torch.randint(0, 1, (1, 1)) - out = torch.ops.aten.gather.default(x, 0, indices) - return out - - # index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32) - index0 = torch.randint(0, 1, (1, 1)) - input = [torch.randn(2, 2), index0] - self.run_test( - TestModule(), - input, - ) - - -if __name__ == "__main__": - run_tests() From 9f751c2fed16564b67ef1c2a52be0768a4b22809 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 22 Feb 2024 13:06:53 -0800 Subject: [PATCH 10/11] select converter implementation correction --- py/torch_tensorrt/dynamo/conversion/impl/select.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 17d2189b7e..7e45b02025 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -69,10 +69,7 @@ def select( indices_tensor = ctx.net.add_constant( index_value.shape, to_numpy(index_value) ).get_output(0) - out = gather(ctx, target, source_ir, name, input, indices_tensor, dim) - if len(out.shape) != 1: - layer = ctx.net.add_shuffle(out) - return layer.get_output(0) + return gather(ctx, target, source_ir, name, input, dim, indices_tensor) def gather( @@ -93,6 +90,9 @@ def gather( # index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir) gather_layer = ctx.net.add_gather(input, index, dim) set_layer_name(gather_layer, target, name + "_gather", source_ir) + out = gather_layer.get_output(0) + if len(out.shape) != 1: + gather_layer = ctx.net.add_shuffle(out) return gather_layer.get_output(0) From a3e65861847e5f8752899d0c0cffbcb23a66d96a Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 7 Mar 2024 16:58:06 -0800 Subject: [PATCH 11/11] removing aten ops gather and changing th e gather tensor --- .../dynamo/conversion/aten_ops_converters.py | 25 ------------------- .../dynamo/conversion/impl/select.py | 3 +-- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 2a3083dd57..478cf98dea 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -183,31 +183,6 @@ def aten_ops_native_group_norm( ) -@dynamo_tensorrt_converter(torch.ops.aten.gather.default) -@enforce_tensor_types( - { - 0: (TRTTensor,), - } -) # type: ignore[misc] -def aten_ops_gather( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.select.gather( - ctx, - target, - SourceIR.ATEN, - name, - input=args[0], - dim=args[1], - index=args[2], - sparse_grad=args_bounds_check(args, 3, False), - ) - - @dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) @dynamo_tensorrt_converter(torch.ops.aten.group_norm) @enforce_tensor_types( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 7e45b02025..ae0fda9897 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -80,10 +80,9 @@ def gather( input: TRTTensor, dim: int, index: Union[TRTTensor, np.ndarray, torch.Tensor], - sparse_grad: bool = False, ) -> TRTTensor: if not isinstance(index, TRTTensor): - index = get_trt_tensor(ctx, index, name + f"_parameter_to_fp32_tensor") + index = get_trt_tensor(ctx, index, name + f"_index_to_fp32_tensor") # This is for the case where torch.ops.aten.gather requires torch.int64 # However TRTInterpreter complains that torch.int64 is not a supported type # So the below cast does not help