@@ -90,7 +90,7 @@ def index(
     # is_numpy is a flag to specify if all the indices are numpy arrays or torch Tensors.
     # If any is not, this flag will be set to False.
     _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
     )
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
         indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
             cum_adv_index = cum_adv_index + adv_index
             multiplier = multiplier * input_shape[adv_indx_indices[i]]
         cum_adv_index = get_trt_tensor(
-            ctx, cum_adv_index, name + f"_index_sum_intermediate"
+            ctx, cum_adv_index, name + "_index_sum_intermediate"
         )
     else:
         multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
             adv_indx_count
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
-            _LOGGER.debug(f"The indices are continuous in this case")
+            _LOGGER.debug("The indices are continuous in this case")
             concat_tensor_reshape.append(
                 get_trt_tensor(ctx, -1, name + "_dynamic_concat")
             )
@@ -287,7 +287,7 @@ def index(
                 source_ir,
             )
             unfold_tensor = regular_index_shuffle_layer.get_output(0)
-            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug("The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

             # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)

         else:
-            _LOGGER.debug(f"The indices are not continuous in this case")
+            _LOGGER.debug("The indices are not continuous in this case")
             concat_final_tensor = []
             concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):
@@ -370,3 +370,21 @@ def index(
         reshape_output = reshape_layer.get_output(0)

     return reshape_output
+
+
+def index_select(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: TRTTensor,
+) -> TRTTensor:
+    # The axis parameter specifies the dimension along which to index.
+    dim = get_positive_dim(dim, len(input.shape))
+    gather_layer = ctx.net.add_gather(input, index, axis=dim)
+
+    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
+
+    return gather_layer.get_output(0)
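
For reference, the new `index_select` converter follows the contract of `torch.index_select`: it gathers whole slices of `input` along dimension `dim` at the positions in `index`, normalizing a negative `dim` first via `get_positive_dim`. A minimal sketch of that contract in plain PyTorch (the tensors below are illustrative and not part of this commit):

    import torch

    x = torch.arange(12).reshape(3, 4)
    idx = torch.tensor([2, 0])

    # Selecting along dim 0 picks whole rows; the TensorRT gather layer
    # added above (axis=dim) reproduces this at conversion time.
    assert torch.equal(torch.index_select(x, 0, idx), x[idx])

    # A negative dim behaves like its normalized positive counterpart,
    # mirroring get_positive_dim(dim, len(input.shape)).
    assert torch.equal(torch.index_select(x, -1, idx), torch.index_select(x, 1, idx))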