From 430e17d6d4043416b41af0b75d396344f40f34b9 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Thu, 21 Sep 2023 17:33:36 -0700
Subject: [PATCH] feat: support deconv (1d, 2d, and Nd) dynamo converter

---
 .../dynamo/conversion/aten_ops_converters.py  |  53 +++--
 .../dynamo/conversion/impl/__init__.py        |   1 +
 .../dynamo/conversion/impl/deconv.py          | 173 ++++++++++++++
 .../conversion/test_deconvolution_aten.py     | 224 ++++++++++++++++++
 4 files changed, 433 insertions(+), 18 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/deconv.py
 create mode 100644 tests/py/dynamo/conversion/test_deconvolution_aten.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 19f273ba3f..502fbdabfd 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -357,14 +357,14 @@ def aten_ops_softmax(
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
-)
+)  # type: ignore[misc]
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
-)
+)  # type: ignore[misc]
 @dynamo_tensorrt_converter(
     torch.ops.aten.split_with_sizes.default,
     capability_validator=dynamic_unsupported_with_args([1]),
-)
+)  # type: ignore[misc]
 def aten_ops_split(
     network: TRTNetwork,
     target: Target,
@@ -1331,7 +1331,7 @@ def aten_ops_less(
 
 
 def conv_param_validator(conv_node: Node) -> bool:
-    return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
+    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
 
 
 @dynamo_tensorrt_converter(
@@ -1344,20 +1344,37 @@ def aten_ops_convolution(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.conv.convNd(
-        network,
-        target,
-        source_ir=SourceIR.ATEN,
-        name=name,
-        is_conv1d=len(args[3]) == 1,
-        input=args[0],
-        weight=args[1],
-        bias=args[2],
-        stride=args[3],
-        padding=args[4],
-        dilation=args[5],
-        groups=args[8],
-    )
+    is_transposed = args[6]
+    if not is_transposed:
+        return impl.conv.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=len(args[3]) == 1,
+            input=args[0],
+            weight=args[1],
+            bias=args[2],
+            stride=args[3],
+            padding=args[4],
+            dilation=args[5],
+            groups=args[8],
+        )
+    else:
+        return impl.deconv.deconvNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_deconv1d=len(args[3]) == 1,
+            input=args[0],
+            weight=args[1],
+            bias=args[2],
+            stride=args[3],
+            padding=args[4],
+            dilation=args[5],
+            groups=args[8],
+        )
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.linear.default)  # type: ignore[misc]
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index e615599eb4..8477b6449b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -5,6 +5,7 @@
     cast,
     condition,
     conv,
+    deconv,
     elementwise,
     embedding,
     linear,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
new file mode 100644
index 0000000000..e0f5844bd7
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -0,0 +1,173 @@
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    get_dyn_range,
+    has_dynamic_shape,
+    mark_as_int8_layer,
+    set_layer_name,
+    to_numpy,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def deconvNd(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    is_deconv1d: bool,
+    input: TRTTensor,
+    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    stride: Optional[Union[int, Sequence[int]]],
+    padding: Optional[Union[int, Sequence[int]]],
+    groups: Optional[int],
+    dilation: Optional[Union[int, Sequence[int]]],
+    scale: Optional[Union[torch.Tensor, float]] = None,
+    zero_point: Optional[Union[torch.Tensor, float]] = None,
+) -> TRTTensor:
+    """Add an N-dimensional deconvolution (transposed convolution) layer.
+
+    A 1d deconvolution is lowered as a 2d one: ``input`` and ``weight`` get a
+    trailing unit dimension, ``padding``/``stride``/``dilation`` are extended to
+    two entries, and the trailing dimension is squeezed from the result.
+
+    Args:
+        network: TensorRT network the deconvolution layer is added to.
+        target: Target passed through to the helper ops and the layer name.
+        source_ir: Source IR passed through to the helper ops and the layer name.
+        name: Base name for the layer and the helper ops created here.
+        is_deconv1d: Whether the op is a 1d deconvolution.
+        input: Input tensor; its channel dimension (index 1) must be static.
+        weight: Kernel, either a TRT tensor or a constant tensor/array.
+        bias: Optional bias, either a TRT tensor or a constant tensor/array.
+        stride: Stride; set on the layer only when not None.
+        padding: Padding; set on the layer only when not None.
+        groups: Number of groups; set on the layer only when not None.
+        dilation: Dilation; set on the layer only when not None.
+        scale: Quantization scale; used only together with ``zero_point``.
+        zero_point: Quantization zero point; used only together with ``scale``.
+
+    Returns:
+        The output tensor of the deconvolution layer.
+
+    Raises:
+        RuntimeError: If ``bias`` or ``weight`` has an unsupported type.
+    """
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for deconvolution."
+
+    if is_deconv1d:
+        # Apply an unsqueeze operation to transform the deconv1d problem into deconv2d
+        input = impl.unsqueeze.unsqueeze(
+            network, target, source_ir, name + "_unsqueeze_deconv1d", input, -1
+        )
+
+    # Process bias terms
+    if isinstance(bias, (torch.Tensor, np.ndarray)):
+        # Transform the bias constant into a Numpy array
+        bias = to_numpy(bias)
+
+    elif isinstance(bias, TRTTensor):
+        bias = get_trt_tensor(network, bias, f"{name}_bias")
+
+    elif bias is not None:
+        raise RuntimeError(
+            f"Deconvolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
+        )
+
+    # Process weight terms
+    if network.has_explicit_precision or isinstance(weight, TRTTensor):
+        weight = get_trt_tensor(network, weight, f"{name}_weight")
+        # Append new dimension (unsqueeze) if the deconvolution is 1d
+        if is_deconv1d:
+            weight = impl.unsqueeze.unsqueeze(
+                network, target, source_ir, name + "_unsqueeze_weight", weight, -1
+            )
+
+    elif isinstance(weight, (torch.Tensor, np.ndarray)):
+        # Transform the weight constant into a Numpy array
+        weight = to_numpy(weight)
+
+        # Append new dimension (unsqueeze) if the deconvolution is 1d
+        if is_deconv1d:
+            weight = np.expand_dims(weight, axis=-1)
+
+    else:
+        raise RuntimeError(
+            f"Deconvolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
+        )
+
+    # Transposed-convolution weights are laid out as
+    # (in_channels, out_channels // groups, *kernel), so the number of output
+    # maps is weight.shape[1] * groups.
+    num_groups = groups if groups is not None else 1
+
+    # add deconv layer
+    deconv_layer = network.add_deconvolution_nd(
+        input=input,
+        num_output_maps=weight.shape[1] * num_groups,
+        kernel_shape=weight.shape[2:],
+        kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
+        bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
+    )
+
+    # If the weight is a TRTTensor, set it as an input of the layer
+    if isinstance(weight, TRTTensor):
+        deconv_layer.set_input(1, weight)
+
+    # If the bias is a TRTTensor, set it as an input of the layer
+    if isinstance(bias, TRTTensor):
+        deconv_layer.set_input(2, bias)
+
+    # Cast certain fields to tuples, in accordance with TRT requirements
+    padding = (padding,) if isinstance(padding, int) else padding
+    stride = (stride,) if isinstance(stride, int) else stride
+    dilation = (dilation,) if isinstance(dilation, int) else dilation
+
+    # Expand parameters manually for deconv1d computations
+    if is_deconv1d:
+        padding = (tuple(padding) + (0,)) if padding is not None else padding
+        stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
+        dilation = (
+            extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
+        )
+
+    set_layer_name(deconv_layer, target, name, source_ir)
+
+    # Set relevant attributes of deconvolution layer
+    if padding is not None:
+        deconv_layer.padding_nd = padding
+    if stride is not None:
+        deconv_layer.stride_nd = stride
+    if dilation is not None:
+        deconv_layer.dilation_nd = dilation
+    if groups is not None:
+        deconv_layer.num_groups = groups
+
+    # Handle quantization cases
+    if scale is not None and zero_point is not None:
+        # Assume the dtype of activation is torch.quint8
+        mark_as_int8_layer(deconv_layer, get_dyn_range(scale, zero_point, torch.quint8))
+
+    result = deconv_layer.get_output(0)
+
+    if is_deconv1d:
+        # Apply a squeeze operation to transform the deconv2d problem back into deconv1d
+        result = impl.squeeze.squeeze(
+            network, target, source_ir, name + "_squeeze_deconv1d", result, -1
+        )
+
+    return result
diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py
new file mode 100644
index 0000000000..939a7ea9c0
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py
@@ -0,0 +1,224 @@
+import torch
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestDeconvolutionConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("default", 1),
+            param("no_bias", 1, bias=False),
+            ("tuple_parameters", 1, (1,), (1,)),
+            param("non_zero_padding", 1, padding=1),
+            param("dilation", 1, dilation=2),
+            param("groups", 1, groups=3),
+        ]
+    )
+    def test_deconv1d(
+        self,
+        _,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose1d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                    bias=bias,
+                )
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        inputs = [torch.randn(1, 3, 32)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={torch.ops.aten.convolution.default},
+        )
+
+    def test_deconv1d_with_dynamic_shape(
+        self,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose1d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                    bias=bias,
+                )
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 3),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+    @parameterized.expand(
+        [
+            ("default", 1),
+            param("no_bias", 1, bias=False),
+            ("tuple_parameters", 1, (1, 1), (1, 1)),
+            param("non_zero_padding", 1, padding=1),
+            param("dilation", 1, dilation=2),
+            param("groups", 1, groups=3),
+        ]
+    )
+    def test_deconv2d(
+        self,
+        _,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose2d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                    bias=bias,
+                )
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        inputs = [torch.randn(1, 3, 32, 32)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+    # Testing with (-1, -1, -1, -1) results into Error:
+    # AssertionError: Channel dim can't be dynamic for deconvolution.
+
+    def test_deconv2d_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose2d(3, 3, 1)
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+    @parameterized.expand(
+        [
+            ("default", 1),
+            param("no_bias", 1, bias=False),
+            ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
+            param("non_zero_padding", 1, padding=1),
+            param("dilation", 1, dilation=2),
+            param("groups", 1, groups=3),
+        ]
+    )
+    def test_deconv3d(
+        self,
+        _,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose3d(
+                    in_channels=3,
+                    out_channels=3,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                    bias=bias,
+                )
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+    # Testing with (-1, -1, -1, -1, -1) results into Error:
+    # AssertionError: Channel dim can't be dynamic for deconvolution.
+
+    def test_deconv3d_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose3d(3, 3, 1)
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()