From 944278ea54d9d2f3f71a82cb50dbd2cebb6e65af Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 25 Sep 2023 12:57:14 -0700 Subject: [PATCH 1/4] aten::cat converter moving to impl --- .../dynamo/conversion/aten_ops_converters.py | 18 ++++++++++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/cat.py | 29 +++++++++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/cat.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 67ce83469f..5c978a45ed 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -51,6 +51,24 @@ def aten_ops_batch_norm( ) +@dynamo_tensorrt_converter(torch.ops.aten.cat.default) +def aten_ops_cat( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.cat.cat( + network, + target, + SourceIR.ATEN, + name, + tensors=args[0], + dim = args_bounds_check(args, 2, 1), + ) + + def embedding_param_validator(embedding_node: Node) -> bool: scale_grad_by_freq = args_bounds_check(embedding_node.args, 3) sparse = args_bounds_check(embedding_node.args, 4) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 3f49377619..e1c135b422 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -4,6 +4,7 @@ activation, attention, cast, + cat, condition, conv, deconv, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py new file mode 100644 index 0000000000..7492e5223b --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -0,0 +1,29 @@ +from typing import Optional, Union, Sequence, Dict + +import 
torch +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +def cat( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTNetwork, + dim: int, + +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr] + raise RuntimeError( + f"cat received inputs {input} that is not part " "of the TensorRT region!" + ) + concat_layer = network.add_concatenation(input) + if dim < 0: + dim = len(input[0].shape) + dim + + concat_layer.axis = dim + set_layer_name(concat_layer, target, name + "_gather", source_ir) + return concat_layer.get_output(0) \ No newline at end of file From c2ac3b47a4409448016b317dd678a7590436d005 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 25 Sep 2023 12:59:28 -0700 Subject: [PATCH 2/4] aten::cat converter moving to impl --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 +- py/torch_tensorrt/dynamo/conversion/impl/cat.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5c978a45ed..6253b98866 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -65,7 +65,7 @@ def aten_ops_cat( SourceIR.ATEN, name, tensors=args[0], - dim = args_bounds_check(args, 2, 1), + dim=args_bounds_check(args, 2, 1), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index 7492e5223b..4dd19e2690 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Sequence, Dict +from typing 
import Dict, Optional, Sequence, Union import torch from torch.fx.node import Target @@ -6,6 +6,7 @@ from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + def cat( network: TRTNetwork, target: Target, @@ -13,9 +14,7 @@ def cat( name: str, input: TRTNetwork, dim: int, - ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr] raise RuntimeError( f"cat received inputs {input} that is not part " "of the TensorRT region!" @@ -26,4 +25,4 @@ def cat( concat_layer.axis = dim set_layer_name(concat_layer, target, name + "_gather", source_ir) - return concat_layer.get_output(0) \ No newline at end of file + return concat_layer.get_output(0) From 759b260e3fb04f6cd860365ffa519eacbd3311ef Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 4 Oct 2023 13:13:39 -0700 Subject: [PATCH 3/4] aten::cat rebase changes --- .../dynamo/conversion/aten_ops_converters.py | 8 ++++---- py/torch_tensorrt/dynamo/conversion/impl/cat.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6253b98866..2a77db5636 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -53,19 +53,19 @@ def aten_ops_batch_norm( @dynamo_tensorrt_converter(torch.ops.aten.cat.default) def aten_ops_cat( - network: TRTNetwork, + ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.cat.cat( - network, + ctx, target, SourceIR.ATEN, name, - tensors=args[0], - dim=args_bounds_check(args, 2, 1), + input=args[0], + dim=args_bounds_check(args, 1, 0), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py 
b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index 4dd19e2690..9c33d2665b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -4,22 +4,23 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.types import TRTNetwork, TRTTensor def cat( - network: TRTNetwork, + ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTNetwork, + input: Union[TRTTensor, Sequence[TRTTensor]], dim: int, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr] - raise RuntimeError( - f"cat received inputs {input} that is not part " "of the TensorRT region!" 
- ) - concat_layer = network.add_concatenation(input) + for each_input in input: + if(not isinstance(each_input, TRTTensor)): + each_input = get_trt_tensor(each_input) + concat_layer = ctx.net.add_concatenation(input) if dim < 0: dim = len(input[0].shape) + dim From 83e4f3ca9d8a6db55aa2d5a2b31eec968ebd71bb Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 4 Oct 2023 15:39:52 -0700 Subject: [PATCH 4/4] Addressing review comments --- .../dynamo/conversion/impl/cat.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index 9c33d2665b..24149d01b0 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -1,11 +1,16 @@ from typing import Dict, Optional, Sequence, Union +import numpy as np import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + get_positive_dim, + get_trt_tensor, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor @@ -14,16 +19,16 @@ def cat( target: Target, source_ir: Optional[SourceIR], name: str, - input: Union[TRTTensor, Sequence[TRTTensor]], + input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]], dim: int, ) -> Union[TRTTensor, Sequence[TRTTensor]]: + trt_inputs = [] for each_input in input: - if(not isinstance(each_input, TRTTensor)): - each_input = get_trt_tensor(each_input) - concat_layer = ctx.net.add_concatenation(input) - if dim < 0: - dim = len(input[0].shape) + dim - + if not isinstance(each_input, TRTTensor): + 
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{len(trt_inputs)}") + trt_inputs.append(each_input) + concat_layer = ctx.net.add_concatenation(trt_inputs) + dim = get_positive_dim(dim, len(input[0].shape)) concat_layer.axis = dim set_layer_name(concat_layer, target, name + "_gather", source_ir) return concat_layer.get_output(0)