Skip to content

Commit 6126e66

Browse files
committed
Gather in impl
1 parent 7011809 commit 6126e66

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,31 @@ def aten_ops_native_group_norm(
183183
)
184184

185185

186+
@dynamo_tensorrt_converter(torch.ops.aten.gather)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_gather(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.gather(self, dim, index, *, sparse_grad=False)`` to TensorRT.

    Args:
        ctx: Conversion context holding the TRT network being built.
        target: The FX node target being converted.
        args: Positional node args: ``args[0]`` input tensor, ``args[1]`` dim,
            ``args[2]`` index tensor.
        kwargs: Keyword node args (unused here).
        name: Unique layer-name prefix for the generated TRT layers.

    Returns:
        The output tensor of the TRT gather layer.
    """
    # NOTE: ``sparse_grad`` (schema position 3) only changes autograd behavior
    # and has no effect on the forward result, so it is deliberately NOT
    # forwarded: ``impl.select.gather`` has no such parameter, and the previous
    # ``sparse_grad=args_bounds_check(args, 4, False)`` both read the wrong
    # positional index (4 instead of 3) and raised TypeError on every call.
    return impl.select.gather(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        dim=args[1],
        index=args[2],
    )
209+
210+
186211
@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
187212
@dynamo_tensorrt_converter(torch.ops.aten.group_norm)
188213
@enforce_tensor_types(

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,26 @@ def select(
6868
indices_tensor = ctx.net.add_constant(
6969
index_value.shape, to_numpy(index_value)
7070
).get_output(0)
71-
layer = ctx.net.add_gather(input, indices_tensor, dim)
72-
out = layer.get_output(0)
71+
out = gather(input, indices_tensor, dim)
7372
if len(out.shape) != 1:
7473
layer = ctx.net.add_shuffle(out)
7574
return layer.get_output(0)
7675

7776

77+
def gather(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: Union[TRTTensor, np.ndarray, torch.Tensor],
) -> TRTTensor:
    """Add a TRT gather layer selecting ``index`` along axis ``dim`` of ``input``.

    Args:
        ctx: Conversion context holding the TRT network being built.
        target: The FX node target being converted (used for layer naming).
        source_ir: Originating IR of the op, for diagnostics/naming.
        name: Unique layer-name prefix; ``"_gather"`` is appended.
        input: The TRT tensor to gather from.
        dim: Axis along which to gather.
        index: A single indices tensor. (The previous annotation claimed a
            ``Sequence`` of tensors, but ``add_gather`` takes exactly one
            indices tensor — annotation corrected, behavior unchanged.)

    Returns:
        Output tensor of the gather layer.
    """
    # NOTE(review): np.ndarray / torch.Tensor indices are passed straight to
    # add_gather here without conversion to an ITensor — presumably callers
    # supply a TRTTensor; confirm before relying on the wider annotation.
    gather_layer = ctx.net.add_gather(input, index, dim)
    set_layer_name(gather_layer, target, name + "_gather", source_ir)
    return gather_layer.get_output(0)
89+
90+
7891
def index(
7992
ctx: ConversionContext,
8093
target: Target,
@@ -127,9 +140,7 @@ def index(
127140
)
128141
index = adv_indx_indices[0]
129142
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
130-
gather_layer = ctx.net.add_gather(input, indices_tensor, index)
131-
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
132-
return gather_layer.get_output(0)
143+
return gather(input, index, indices_tensor)
133144
else:
134145
input_shape = input.shape
135146
_LOGGER.debug(f"The input shape is {input.shape}")
@@ -242,11 +253,7 @@ def index(
242253
dim_tensor_list[adv_indx_indices[i]],
243254
)
244255

245-
gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
246-
set_layer_name(
247-
gather_layer_element, target, name + "_index_gather_element", source_ir
248-
)
249-
gather_out = gather_layer_element.get_output(0)
256+
gather_out = gather(flatten_tensor, cum_adv_index, 0)
250257
_LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
251258
_LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
252259

0 commit comments

Comments (0)