From fb91a7559f09f267ebb585d9df962ad66486ca0e Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 30 Oct 2023 12:36:18 -0700 Subject: [PATCH 1/2] Index ITensor test --- tests/py/dynamo/conversion/test_index_aten.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 393eb53c63..de9b843a89 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -5,7 +5,7 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from .harness import DispatchTestCase +from harness import DispatchTestCase class TestIndexConverter(DispatchTestCase): @@ -26,6 +26,21 @@ def forward(self, x): TestModule(), input, ) + + def test_index_zero_two_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test( + TestModule(), + [input, index0], + ) def test_index_zero_index_three_dim(self): class TestModule(nn.Module): @@ -43,6 +58,21 @@ def forward(self, x): TestModule(), input, ) + + def test_index_zero_index_three_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0, None] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test( + TestModule(), + [input, index0] + ) def test_index_zero_index_one_index_two_three_dim(self): class TestModule(nn.Module): From a6885cd7550f66b86aeb2fe1e091da920ddb47db Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 30 Oct 2023 13:00:09 -0700 Subject: [PATCH 2/2] Adding impl function for gather --- .../dynamo/conversion/aten_ops_converters.py | 25 +++++++++++++++++++ .../dynamo/conversion/impl/select.py | 25 +++++++++++++------ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 70c4574b94..1e46d1d7d9 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -147,6 +147,31 @@ def aten_ops_native_group_norm( ) +@dynamo_tensorrt_converter(torch.ops.aten.gather) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] +def aten_ops_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.group_norm( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + dim=args[1], + index=args[2], + sparse_grad = args_bounds_check(args, 4, False), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.group_norm) # type: ignore[misc] @enforce_tensor_types( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index db586be65f..bf64fd3475 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -75,6 +75,20 @@ def select( return layer.get_output(0) +def gather( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], +) -> TRTTensor: + gather_layer = ctx.net.add_gather(input, index, index) + set_layer_name(gather_layer, target, name + "_gather", source_ir) + return gather_layer.get_output(0) + + def index( ctx: ConversionContext, target: Target, @@ -127,9 +141,8 @@ def index( ) index = adv_indx_indices[0] _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") - gather_layer = ctx.net.add_gather(input, indices_tensor, index) - set_layer_name(gather_layer, target, name + "_index_gather", source_ir) - return gather_layer.get_output(0) + return gather(input, index, indices_tensor) + else: input_shape = input.shape _LOGGER.debug(f"The input shape is {input.shape}") @@ -242,11 +255,7 @@ def index( dim_tensor_list[adv_indx_indices[i]], ) - gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0) - set_layer_name( - gather_layer_element, target, name + "_index_gather_element", source_ir - ) - gather_out = gather_layer_element.get_output(0) + gather_out = gather(flatten_tensor, 0, cum_adv_index) _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}") _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")