Gather Implementation #2457


Closed
wants to merge 11 commits
38 changes: 27 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -68,11 +69,30 @@ def select(
     indices_tensor = ctx.net.add_constant(
         index_value.shape, to_numpy(index_value)
     ).get_output(0)
-    layer = ctx.net.add_gather(input, indices_tensor, dim)
-    out = layer.get_output(0)
+    return gather(ctx, target, source_ir, name, input, dim, indices_tensor)
+
+
+def gather(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
+) -> TRTTensor:
+    if not isinstance(index, TRTTensor):
+        index = get_trt_tensor(ctx, index, name + f"_index_to_fp32_tensor")
+    # This is for the case where torch.ops.aten.gather requires torch.int64;
+    # however, TRTInterpreter complains that torch.int64 is not a supported type,
+    # so the cast below does not help:
+    # index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir)
Comment on lines +86 to +89
Collaborator:
Does this issue still occur in the test cases?

Collaborator (Author):
Yes, it does. aten.scatter has similar use cases, so I am working on that. The casting of nodes in the TRT test infrastructure can be used here to work around it; see PR #2664.
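
For reference, a minimal sketch (not code from this PR) of the int32-cast workaround under discussion, assuming the cast_trt_tensor call signature shown in the commented-out line above. The helper name, tensor names, and dtype check are illustrative assumptions; note the cast is applied to index here, rather than to input as in the commented-out line:

import tensorrt as trt
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.dynamo.conversion.converter_utils import (
    cast_trt_tensor,
    get_trt_tensor,
)

def gather_with_int32_index(ctx, target, source_ir, name, input, dim, index):
    # Hypothetical sketch: normalize the index tensor to int32 before
    # add_gather, since int64 index tensors are rejected by the interpreter.
    if not isinstance(index, TRTTensor):
        index = get_trt_tensor(ctx, index, name + "_index_tensor")
    if index.dtype != trt.int32:
        index = cast_trt_tensor(
            ctx, index, trt.int32, name + "_index_cast_int32", target, source_ir
        )
    gather_layer = ctx.net.add_gather(input, index, dim)
    return gather_layer.get_output(0)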

+    gather_layer = ctx.net.add_gather(input, index, dim)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    out = gather_layer.get_output(0)
     if len(out.shape) != 1:
-        layer = ctx.net.add_shuffle(out)
-    return layer.get_output(0)
+        gather_layer = ctx.net.add_shuffle(out)
+    return gather_layer.get_output(0)
 
 
 def index(
@@ -127,9 +147,7 @@ def index(
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
-        gather_layer = ctx.net.add_gather(input, indices_tensor, index)
-        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
-        return gather_layer.get_output(0)
+        return gather(ctx, target, source_ir, name, input, index, indices_tensor)

Collaborator:
Line 126/147 needs to be renamed, since it uses name + f"_parameter_to_fp32_tensor", which also appears in the gather function. This could cause a duplicate-name error in edge cases.
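
As a minimal sketch of the requested rename (the suffix here is a hypothetical choice, not from this PR), the index() call site could register its tensor under a suffix that gather() never uses:

# Hypothetical fix: give the index() call site its own suffix instead of
# reusing a name that gather() may also register.
indices_tensor = get_trt_tensor(
    ctx, index, name + "_index_parameter_to_fp32_tensor"
)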

     else:
         input_shape = input.shape
         _LOGGER.debug(f"The input shape is {input.shape}")
@@ -242,11 +260,9 @@
             dim_tensor_list[adv_indx_indices[i]],
         )
 
-        gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
-        set_layer_name(
-            gather_layer_element, target, name + "_index_gather_element", source_ir
+        gather_out = gather(
+            ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index
         )
-        gather_out = gather_layer_element.get_output(0)
         _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
         _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
 
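For orientation (not part of the diff): in its default mode, ctx.net.add_gather with a 1-D index tensor picks whole slices along the given dimension, which is the behavior the select and index converters above rely on. It mirrors torch.index_select and can be sanity-checked in plain PyTorch:

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
idx = torch.tensor([2, 0])

# Picking slices along dim=1 with a 1-D index tensor:
# column 2 first, then column 0.
print(torch.index_select(x, 1, idx))  # tensor([[3, 1], [6, 4]])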