@@ -68,13 +68,26 @@ def select(
68
68
indices_tensor = ctx .net .add_constant (
69
69
index_value .shape , to_numpy (index_value )
70
70
).get_output (0 )
71
- layer = ctx .net .add_gather (input , indices_tensor , dim )
72
- out = layer .get_output (0 )
71
+ out = gather (input , indices_tensor , dim )
73
72
if len (out .shape ) != 1 :
74
73
layer = ctx .net .add_shuffle (out )
75
74
return layer .get_output (0 )
76
75
77
76
77
+ def gather (
78
+ ctx : ConversionContext ,
79
+ target : Target ,
80
+ source_ir : Optional [SourceIR ],
81
+ name : str ,
82
+ input : TRTTensor ,
83
+ dim : int ,
84
+ index : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
85
+ ) -> TRTTensor :
86
+ gather_layer = ctx .net .add_gather (input , index , dim )
87
+ set_layer_name (gather_layer , target , name + "_gather" , source_ir )
88
+ return gather_layer .get_output (0 )
89
+
90
+
78
91
def index (
79
92
ctx : ConversionContext ,
80
93
target : Target ,
@@ -127,9 +140,7 @@ def index(
127
140
)
128
141
index = adv_indx_indices [0 ]
129
142
_LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
130
- gather_layer = ctx .net .add_gather (input , indices_tensor , index )
131
- set_layer_name (gather_layer , target , name + "_index_gather" , source_ir )
132
- return gather_layer .get_output (0 )
143
+ return gather (input , index , indices_tensor )
133
144
else :
134
145
input_shape = input .shape
135
146
_LOGGER .debug (f"The input shape is { input .shape } " )
@@ -242,11 +253,7 @@ def index(
242
253
dim_tensor_list [adv_indx_indices [i ]],
243
254
)
244
255
245
- gather_layer_element = ctx .net .add_gather (flatten_tensor , cum_adv_index , 0 )
246
- set_layer_name (
247
- gather_layer_element , target , name + "_index_gather_element" , source_ir
248
- )
249
- gather_out = gather_layer_element .get_output (0 )
256
+ gather_out = gather (flatten_tensor , cum_adv_index , 0 )
250
257
_LOGGER .debug (f"The shape after cumultative gather is { gather_out .shape } " )
251
258
_LOGGER .debug (f"The shape for cumulative adv index is { cum_adv_index } " )
252
259
0 commit comments