Skip to content

Commit 97e3758

Browse files
committed
Addressing review comments
1 parent 0fc9c75 commit 97e3758

File tree

1 file changed

+10
-23
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+10
-23
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -81,30 +81,20 @@ def index(
8181
source_ir: Optional[SourceIR],
8282
name: str,
8383
input: TRTTensor,
84-
index: Union[
85-
TRTTensor,
86-
Sequence[TRTTensor],
87-
np.ndarray,
88-
Sequence[np.ndarray],
89-
torch.Tensor,
90-
Sequence[torch.Tensor],
91-
],
84+
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
9285
) -> TRTTensor:
9386
adv_indx_indices = []
9487
tensor_indices = []
95-
# _LOGGER.debug(f"The index shape is {index.shape}")
9688
# check if the input is dynamic
9789
dynamic_shape = has_dynamic_shape(input.shape)
9890
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
9991
# If any is not this flag will be set to False
100-
is_numpy = True
101-
_LOGGER.debug(f"Checking for the is_numpy flag")
102-
for i, ind in enumerate(index):
103-
if ind is None:
104-
continue
105-
if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)):
106-
is_numpy = False
107-
break
92+
_LOGGER.debug(
93+
f"Determining whether aten.index constant-index optimization can be invoked"
94+
)
95+
is_numpy = all(
96+
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
97+
)
10898
# here we need to check if all the index are broadcastable
10999
# if no, then we need to broadcast
110100
last_index = None
@@ -117,7 +107,6 @@ def index(
117107
# other cases are kept as TRTTensor
118108
if is_numpy:
119109
ind = to_numpy(ind)
120-
is_numpy = True
121110
else:
122111
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
123112
if last_index is not None:
@@ -156,9 +145,7 @@ def index(
156145
for i in range(rank):
157146
dim = input_shape[i]
158147
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
159-
# dim_tensor_list is a list of tensors or numpy
160-
if is_numpy:
161-
dim_list.append(dim)
148+
# dim_tensor_list is a list of tensors
162149
dim_tensor_list.append(dim_tensor)
163150

164151
# for cases like
@@ -211,12 +198,12 @@ def index(
211198
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
212199
# // j dimension of input x.
213200
if is_numpy:
214-
multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]]
201+
multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
215202
cum_adv_index = tensor_indices[adv_indx_count - 1]
216203
for i in range(adv_indx_count - 2, -1, -1):
217204
adv_index = multiplier * tensor_indices[i]
218205
cum_adv_index = cum_adv_index + adv_index
219-
multiplier = multiplier * dim_list[adv_indx_indices[i]]
206+
multiplier = multiplier * input_shape[adv_indx_indices[i]]
220207
cum_adv_index = get_trt_tensor(
221208
ctx, cum_adv_index, name + f"_index_sum_intermediate"
222209
)

0 commit comments

Comments
 (0)