Skip to content

Commit cec3835

Browse files
authored
feat: support aten.index_select converter (#2710)
1 parent 821ff91 commit cec3835

File tree

3 files changed

+90
-6
lines changed

3 files changed

+90
-6
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2829,3 +2829,28 @@ def aten_ops_roll(
28292829
args[1],
28302830
args_bounds_check(args, 2, []),
28312831
)
2832+
2833+
2834+
@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
2835+
@enforce_tensor_types(
2836+
{
2837+
0: (TRTTensor,),
2838+
2: (TRTTensor,),
2839+
}
2840+
)
2841+
def aten_ops_index_select(
2842+
ctx: ConversionContext,
2843+
target: Target,
2844+
args: Tuple[Argument, ...],
2845+
kwargs: Dict[str, Argument],
2846+
name: str,
2847+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2848+
return impl.select.index_select(
2849+
ctx,
2850+
target,
2851+
SourceIR.ATEN,
2852+
name,
2853+
args[0],
2854+
args[1],
2855+
args[2],
2856+
)

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def index(
9090
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
9191
# If any is not this flag will be set to False
9292
_LOGGER.debug(
93-
f"Determining whether aten.index constant-index optimization can be invoked"
93+
"Determining whether aten.index constant-index optimization can be invoked"
9494
)
9595
is_numpy = all(
9696
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
123123
return identity_layer.get_output(0)
124124
elif len(tensor_indices) == 1:
125125
indices_tensor = get_trt_tensor(
126-
ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
126+
ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
127127
)
128128
index = adv_indx_indices[0]
129129
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
204204
cum_adv_index = cum_adv_index + adv_index
205205
multiplier = multiplier * input_shape[adv_indx_indices[i]]
206206
cum_adv_index = get_trt_tensor(
207-
ctx, cum_adv_index, name + f"_index_sum_intermediate"
207+
ctx, cum_adv_index, name + "_index_sum_intermediate"
208208
)
209209
else:
210210
multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
263263
adv_indx_count
264264
== adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
265265
):
266-
_LOGGER.debug(f"The indices are continuous in this case")
266+
_LOGGER.debug("The indices are continuous in this case")
267267
concat_tensor_reshape.append(
268268
get_trt_tensor(ctx, -1, name + "_dynamic_concat")
269269
)
@@ -287,7 +287,7 @@ def index(
287287
source_ir,
288288
)
289289
unfold_tensor = regular_index_shuffle_layer.get_output(0)
290-
_LOGGER.debug(f"The tensor is unfolded now")
290+
_LOGGER.debug("The tensor is unfolded now")
291291
_LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
292292

293293
# Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
342342
reshape_output = unfold_advanced_shuffle_layer.get_output(0)
343343

344344
else:
345-
_LOGGER.debug(f"The indices are not continuous in this case")
345+
_LOGGER.debug("The indices are not continuous in this case")
346346
concat_final_tensor = []
347347
concat_final_tensor.append(cum_adv_index_shape_tensor)
348348
for i in range(0, rank):
@@ -370,3 +370,21 @@ def index(
370370
reshape_output = reshape_layer.get_output(0)
371371

372372
return reshape_output
373+
374+
375+
def index_select(
376+
ctx: ConversionContext,
377+
target: Target,
378+
source_ir: Optional[SourceIR],
379+
name: str,
380+
input: TRTTensor,
381+
dim: int,
382+
index: TRTTensor,
383+
) -> TRTTensor:
384+
# The axis parameter specifies the dimension along which to index.
385+
dim = get_positive_dim(dim, len(input.shape))
386+
gather_layer = ctx.net.add_gather(input, index, axis=dim)
387+
388+
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
389+
390+
return gather_layer.get_output(0)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestIndexSelectConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("1d_input", (10,), 0, (1,)),
13+
("2d_input_dim_0", (10, 3), 0, (0, 2)),
14+
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
15+
("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)),
16+
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
17+
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
18+
("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)),
19+
("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)),
20+
]
21+
)
22+
def test_index_select(self, _, source_shape, dim, indices_val):
23+
class TestIndexSelect(torch.nn.Module):
24+
def forward(self, source_tensor, indices_tensor):
25+
return torch.ops.aten.index_select.default(
26+
source_tensor, dim, indices_tensor
27+
)
28+
29+
input = [
30+
torch.randn(*source_shape, dtype=torch.float32),
31+
torch.tensor([*indices_val], dtype=torch.int32),
32+
]
33+
34+
self.run_test(
35+
TestIndexSelect(),
36+
input,
37+
)
38+
39+
40+
if __name__ == "__main__":
41+
run_tests()

0 commit comments

Comments
 (0)