diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 67ce83469f..2a77db5636 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -51,6 +51,24 @@ def aten_ops_batch_norm(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
+def aten_ops_cat(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cat.cat(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1, 0),
+    )
+
+
 def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 3f49377619..e1c135b422 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -4,6 +4,7 @@
     activation,
     attention,
     cast,
+    cat,
     condition,
     conv,
     deconv,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
new file mode 100644
index 0000000000..24149d01b0
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -0,0 +1,40 @@
+from typing import Optional, Sequence, Union
+
+import numpy as np
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def cat(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    dim: int,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    """Concatenate *input* tensors along *dim* via a TRT concatenation layer.
+
+    Non-TRT inputs (torch / numpy tensors) are first materialized as TRT
+    constants. *dim* may be negative; it is normalized against the rank of
+    the first input.
+    """
+    trt_inputs = []
+    for i, each_input in enumerate(input):
+        if not isinstance(each_input, TRTTensor):
+            # Unique per-input suffix: a literal "{i}" here would give every
+            # constant the same layer name and collide inside the network.
+            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
+        trt_inputs.append(each_input)
+    concat_layer = ctx.net.add_concatenation(trt_inputs)
+    concat_layer.axis = get_positive_dim(dim, len(input[0].shape))
+    set_layer_name(concat_layer, target, f"{name}_concat", source_ir)
+    return concat_layer.get_output(0)