@@ -81,30 +81,20 @@ def index(
81
81
source_ir : Optional [SourceIR ],
82
82
name : str ,
83
83
input : TRTTensor ,
84
- index : Union [
85
- TRTTensor ,
86
- Sequence [TRTTensor ],
87
- np .ndarray ,
88
- Sequence [np .ndarray ],
89
- torch .Tensor ,
90
- Sequence [torch .Tensor ],
91
- ],
84
+ index : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
92
85
) -> TRTTensor :
93
86
adv_indx_indices = []
94
87
tensor_indices = []
95
- # _LOGGER.debug(f"The index shape is {index.shape}")
96
88
# check if the input is dynamic
97
89
dynamic_shape = has_dynamic_shape (input .shape )
98
90
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
99
91
# If any is not this flag will be set to False
100
- is_numpy = True
101
- _LOGGER .debug (f"Checking for the is_numpy flag" )
102
- for i , ind in enumerate (index ):
103
- if ind is None :
104
- continue
105
- if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
106
- is_numpy = False
107
- break
92
+ _LOGGER .debug (
93
+ f"Determining whether aten.index constant-index optimization can be invoked"
94
+ )
95
+ is_numpy = all (
96
+ isinstance (ind , (torch .Tensor , np .ndarray )) for ind in index if ind is not None
97
+ )
108
98
# here we need to check if all the index are broadcastable
109
99
# if no, then we need to broadcast
110
100
last_index = None
@@ -117,7 +107,6 @@ def index(
117
107
# other cases are kept as TRTTensor
118
108
if is_numpy :
119
109
ind = to_numpy (ind )
120
- is_numpy = True
121
110
else :
122
111
ind = get_trt_tensor (ctx , ind , name + f"_parameter_to_fp32_tensor_{ i } " )
123
112
if last_index is not None :
@@ -156,9 +145,7 @@ def index(
156
145
for i in range (rank ):
157
146
dim = input_shape [i ]
158
147
dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
159
- # dim_tensor_list is a list of tensors or numpy
160
- if is_numpy :
161
- dim_list .append (dim )
148
+ # dim_tensor_list is a list of tensors
162
149
dim_tensor_list .append (dim_tensor )
163
150
164
151
# for cases like
@@ -211,12 +198,12 @@ def index(
211
198
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
212
199
# // j dimension of input x.
213
200
if is_numpy :
214
- multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
201
+ multiplier = input_shape [adv_indx_indices [adv_indx_count - 1 ]]
215
202
cum_adv_index = tensor_indices [adv_indx_count - 1 ]
216
203
for i in range (adv_indx_count - 2 , - 1 , - 1 ):
217
204
adv_index = multiplier * tensor_indices [i ]
218
205
cum_adv_index = cum_adv_index + adv_index
219
- multiplier = multiplier * dim_list [adv_indx_indices [i ]]
206
+ multiplier = multiplier * input_shape [adv_indx_indices [i ]]
220
207
cum_adv_index = get_trt_tensor (
221
208
ctx , cum_adv_index , name + f"_index_sum_intermediate"
222
209
)
0 commit comments