diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index db586be65f..ae0fda9897 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -68,11 +69,30 @@ def select(
         indices_tensor = ctx.net.add_constant(
             index_value.shape, to_numpy(index_value)
         ).get_output(0)
-        layer = ctx.net.add_gather(input, indices_tensor, dim)
-        out = layer.get_output(0)
+        return gather(ctx, target, source_ir, name, input, dim, indices_tensor)
+
+
+def gather(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
+) -> TRTTensor:
+    if not isinstance(index, TRTTensor):
+        index = get_trt_tensor(ctx, index, name + f"_index_to_fp32_tensor")
+    # This is for the case where torch.ops.aten.gather requires torch.int64
+    # However TRTInterpreter complains that torch.int64 is not a supported type
+    # So the below cast does not help
+    # index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir)
+    gather_layer = ctx.net.add_gather(input, index, dim)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    out = gather_layer.get_output(0)
     if len(out.shape) != 1:
-        layer = ctx.net.add_shuffle(out)
-        return layer.get_output(0)
+        gather_layer = ctx.net.add_shuffle(out)
+        return gather_layer.get_output(0)
 
 
 def index(
@@ -127,9 +147,7 @@ def index(
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
-        gather_layer = ctx.net.add_gather(input, indices_tensor, index)
-        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
-        return gather_layer.get_output(0)
+        return gather(ctx, target, source_ir, name, input, index, indices_tensor)
     else:
         input_shape = input.shape
         _LOGGER.debug(f"The input shape is {input.shape}")
@@ -242,11 +260,9 @@
                 dim_tensor_list[adv_indx_indices[i]],
             )
 
-        gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
-        set_layer_name(
-            gather_layer_element, target, name + "_index_gather_element", source_ir
+        gather_out = gather(
+            ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index
         )
-        gather_out = gather_layer_element.get_output(0)
 
         _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
         _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")