From 00ab80cd641e69e9aeed5bbeae67bea3a269712a Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Mon, 20 Feb 2023 14:33:14 -0800
Subject: [PATCH] refactor(//py/torch_tensorrt/fx/converters): Reorganize the
 converters so that the three tracing paths call down into a common converter
 base instead of across each other

---
 .../fx/converters/acc_ops_converters.py       | 110 +-------
 py/torch_tensorrt/fx/converters/activation.py | 242 +++++++++++++++---
 .../fx/converters/aten_ops_converters.py      |   3 +-
 .../fx/converters/nn_ops_converters.py        |  23 ++
 4 files changed, 236 insertions(+), 142 deletions(-)
 create mode 100644 py/torch_tensorrt/fx/converters/nn_ops_converters.py

diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 77a9b92dfe..3470962a5e 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -26,6 +26,7 @@
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
+from . import activation
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -1004,9 +1005,7 @@ def acc_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.RELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.add_relu(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.leaky_relu)
@@ -1017,12 +1016,7 @@ def acc_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    negative_slope = kwargs["negative_slope"]
-    operation_type = trt.ActivationType.LEAKY_RELU
-    return add_activation_layer(
-        network, input_val, operation_type, target, name, negative_slope
-    )
+    return activation.add_leaky_relu(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.elu)
@@ -1033,11 +1027,7 @@ def acc_ops_elu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    alpha = kwargs["alpha"]
-    operation_type = trt.ActivationType.ELU
-    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
-
+    return activation.add_elu(network, target, kwargs, name)
 
 @tensorrt_converter(acc_ops.selu)
 def acc_ops_selu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.SELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.add_selu(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.softsign)
@@ -1060,10 +1048,7 @@ def acc_ops_softsign(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.SOFTSIGN
-    return add_activation_layer(network, input_val, operation_type, target, name)
-
+    return activation.add_softsign(network, target, kwargs, name)
 
 @tensorrt_converter(acc_ops.sin)
 def acc_ops_sin(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
@@ -1138,10 +1123,7 @@ def acc_ops_tanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.TANH
-    return add_activation_layer(network, input_val, operation_type, target, name)
-
+    return activation.add_tanh(network, target, kwargs, name)
 
 @tensorrt_converter(acc_ops.asin)
 def acc_ops_asin(
@@ -3129,23 +3111,7 @@ def acc_ops_hard_sigmoid(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Hard sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_activation_layer(
-        network,
-        input_val,
-        trt.ActivationType.HARD_SIGMOID,
-        target,
-        name,
-        alpha=1 / 6,
-        beta=0.5,
-    )
+    return activation.add_hard_sigmoid(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.sigmoid)
@@ -3156,17 +3122,7 @@ def acc_ops_sigmoid(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
-    )
+    return activation.add_sigmoid(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.permute)
@@ -3367,34 +3323,7 @@ def acc_ops_gelu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    approximate = kwargs["approximate"]
-    if approximate != "none":
-        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"GELU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "GeLU converter currently doesn't support implicit batch dimension"
-        )
-
-    plugin_name = "CustomGeluPluginDynamic"
-    # type_id 0 for float32, 1 for float16
-    type_id = trt.PluginField(
-        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
-    )
-    field_collection = TRTPluginFieldCollection([type_id])
-    plugin_version = "1"
-
-    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
-
-    layer = network.add_plugin_v2([input_val], plugin)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
+    return activation.add_gelu(network, target, kwargs, name)
 
 @tensorrt_converter(acc_ops.chunk)
 def acc_ops_chunk(
@@ -3549,24 +3478,7 @@ def acc_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"hardtanh received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return add_activation_layer(
-        network,
-        input_val,
-        trt.ActivationType.CLIP,
-        target,
-        name,
-        alpha=kwargs["min_val"],
-        beta=kwargs["max_val"],
-    )
-
+    return activation.add_hardtanh(network, target, kwargs, name)
 
 @tensorrt_converter(acc_ops.interpolate)
 def acc_ops_interpolate(
diff --git a/py/torch_tensorrt/fx/converters/activation.py b/py/torch_tensorrt/fx/converters/activation.py
index a7ab25152c..20b88270c1 100644
--- a/py/torch_tensorrt/fx/converters/activation.py
+++ b/py/torch_tensorrt/fx/converters/activation.py
@@ -1,64 +1,206 @@
+import operator
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
 import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
+from torch.fx.node import Argument, Target
+
+from ..types import (
+    Shape,
+    TRTDataType,
+    TRTElementWiseOp,
+    TRTLayer,
+    TRTNetwork,
+    TRTPlugin,
+    TRTPluginFieldCollection,
+    TRTTensor,
+)
+from ..utils import torch_dtype_from_trt
+from .converter_utils import get_trt_plugin, mark_as_int8_layer, set_layer_name
+
+
+def add_activation_layer(
+    network: TRTNetwork,
+    input_val: TRTTensor,
+    operation_type: trt.ActivationType,
+    target: Target,
+    name: str,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+    dyn_range_fn: Optional[Callable[[Tuple[float, float]], Tuple[float, float]]] = None,
+) -> TRTTensor:
+    """
+    Add a TensorRT Activation layer to `network`.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        input_val (TRTTensor): Input to the activation op.
+            Must be a TensorRT tensor.
+        operation_type (trt.ActivationType): Type of the TensorRT activation
+            operation.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+        alpha (Optional[Any]): If not None, we will use it to set the alpha
+            attribute of the created TensorRT activation layer.
+        beta (Optional[Any]): If not None, we will use it to set the beta
+            attribute of the created TensorRT activation layer.
+        dyn_range_fn (Optional[Callable]): A function which takes the dynamic
+            range of the input TensorRT tensor and returns the output dynamic range.
+
+    Returns:
+        The output of the TensorRT Activation layer.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name)
+
+    if input_val.dynamic_range is not None and dyn_range_fn is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
 
-from ..converter_registry import tensorrt_converter
-
-from .converter_utils import mark_as_int8_layer
+    return layer.get_output(0)
+
+
+def add_elu(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    alpha = kwargs["alpha"]
+    operation_type = trt.ActivationType.ELU
+    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
+
+
+def add_gelu(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    approximate = kwargs["approximate"]
+    if approximate != "none":
+        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"GELU received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    if network.has_implicit_batch_dimension:
+        raise RuntimeError(
+            "GeLU converter currently doesn't support implicit batch dimension"
+        )
 
-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
+    plugin_name = "CustomGeluPluginDynamic"
+    # type_id 0 for float32, 1 for float16
+    type_id = trt.PluginField(
+        "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    field_collection = TRTPluginFieldCollection([type_id])
+    plugin_version = "1"
 
-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
+    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)
 
+    layer = network.add_plugin_v2([input_val], plugin)
+    set_layer_name(layer, target, name)
     return layer.get_output(0)
 
+
+def add_hard_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    return add_activation_layer(
+        network,
+        input_val,
+        trt.ActivationType.HARD_SIGMOID,
+        target,
+        name,
+        alpha=1 / 6,
+        beta=0.5,
+    )
 
-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
 
+def add_hardtanh(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
     input_val = kwargs["input"]
 
-    if not isinstance(input_val, trt.tensorrt.ITensor):
+    if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
+            f"hardtanh received input {input_val} that is not part "
             "of the TensorRT region!"
         )
 
-    def activation_dyn_range_fn(dyn_range):
-        return max(0, dyn_range[0]), max(0, dyn_range[1])
-
-    return common_activation(
+    return add_activation_layer(
         network,
-        submod,
         input_val,
-        trt.ActivationType.RELU,
-        activation_dyn_range_fn,
-        layer_name,
+        trt.ActivationType.CLIP,
+        target,
+        name,
+        alpha=kwargs["min_val"],
+        beta=kwargs["max_val"],
     )
 
 
-@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
-def sigmoid(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
+def add_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
     input_val = kwargs["input"]
+    negative_slope = kwargs["negative_slope"]
+    operation_type = trt.ActivationType.LEAKY_RELU
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, negative_slope
+    )
 
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+
+
+def add_relu(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.RELU
+
+    def activation_dyn_range_fn(dyn_range):
+        return max(0, dyn_range[0]), max(0, dyn_range[1])
+
+    return add_activation_layer(
+        network, input_val, operation_type, target, name, dyn_range_fn=activation_dyn_range_fn
+    )
+
+
+def add_selu(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SELU
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
 
     def activation_dyn_range_fn(dyn_range):
         def sigmoid_fn(x):
@@ -66,11 +208,27 @@ def sigmoid_fn(x):
 
         return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
 
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.SIGMOID,
-        activation_dyn_range_fn,
-        layer_name,
+    return add_activation_layer(
+        network, input_val, trt.ActivationType.SIGMOID, target, name, dyn_range_fn=activation_dyn_range_fn
     )
+
+
+def add_softsign(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SOFTSIGN
+    return add_activation_layer(network, input_val, operation_type, target, name)
+
+
+def add_tanh(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.TANH
+    return add_activation_layer(network, input_val, operation_type, target, name)
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 943eb203b3..5f69bdf603 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -22,6 +22,7 @@
 from .converter_utils import *  # noqa: F403
 
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+from . import activation
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -280,7 +281,7 @@ def aten_ops_relu(
     kwargs_new = {
         "input": args[0],
     }
-    return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
+    return activation.add_relu(network, target, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py
new file mode 100644
index 0000000000..9745039bd9
--- /dev/null
+++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+
+from ..converter_registry import tensorrt_converter
+
+from .converter_utils import mark_as_int8_layer
+from . import activation
+
+@tensorrt_converter(torch.nn.functional.relu)
+@tensorrt_converter(torch.nn.modules.activation.ReLU)
+def relu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    return activation.add_relu(network, "tensorrt", kwargs, layer_name)
+
+@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
+def sigmoid(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    return activation.add_sigmoid(network, "tensorrt", kwargs, layer_name)
\ No newline at end of file
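
Reviewer note (not part of the patch): the sketch below drives the new common base directly against a hand-built TensorRT network, to show what all three tracing paths (acc, aten, nn) now reduce to: a kwargs dict with an "input" entry, a target, and a layer name, with layer creation, alpha/beta handling, and INT8 dynamic-range propagation centralized in add_activation_layer. It is a minimal sketch, assuming TensorRT 8.x, that this patch is applied, and that the module is importable as torch_tensorrt.fx.converters.activation; the "acc_ops.relu" target string and the input shape are made up for the example.

    # Illustrative only; assumes the refactored activation module from this patch.
    import tensorrt as trt

    from torch_tensorrt.fx.converters import activation

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # Explicit-batch network, matching what the fx interpreter builds.
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    x = network.add_input(name="x", dtype=trt.float32, shape=(1, 3, 8, 8))

    # Same call shape regardless of which tracer produced the node.
    # The INT8 dynamic-range branch is skipped here since no range is set on x.
    out = activation.add_relu(network, "acc_ops.relu", {"input": x}, "relu_1")
    print(out.shape)  # ReLU is shape-preserving: (1, 3, 8, 8)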