diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 19f273ba3f..dd37d72815 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1378,3 +1378,93 @@ def aten_ops_linear(
         weight=args[1],
         bias=args_bounds_check(args, 2, None),
     )
+
+
+def avg_pool_param_validator(pool_node: Node) -> bool:
+    ceil_mode = args_bounds_check(pool_node.args, 4, False)
+    divisor_override = args_bounds_check(pool_node.args, 6)
+
+    if ceil_mode is not False:
+        _LOGGER.debug(
+            f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
+        )
+        return False
+
+    if divisor_override is not None:
+        _LOGGER.debug(
+            f"Currently we don't support divisor_override, got divisor_override={divisor_override}."
+        )
+        return False
+
+    return True
+
+
+# Note: AvgPool1d uses avg_pool2d as it converts to 2D first.
+@dynamo_tensorrt_converter(torch.ops.aten.avg_pool1d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
+def aten_ops_avg_pool(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.avg_poolNd(
+        network,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        kernel_size=args[1],
+        stride=args_bounds_check(args, 2, replacement=[]),
+        padding=args_bounds_check(args, 3, replacement=0),
+        ceil_mode=args_bounds_check(args, 4, replacement=False),
+        count_include_pad=args_bounds_check(args, 5, replacement=True),
+        divisor_override=args_bounds_check(args, 6, replacement=None),
+    )
+
+
+def max_pool_param_validator(pool_node: Node) -> bool:
+    dilation = args_bounds_check(pool_node.args, 4, 1)
+    ceil_mode = args_bounds_check(pool_node.args, 5, False)
+
+    if dilation != 1:
+        _LOGGER.debug(f"Currently we don't support dilation, got dilation={dilation}.")
+        return False
+
+    if ceil_mode is not False:
+        _LOGGER.debug(
+            f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
+        )
+        return False
+
+    return True
+
+
+# Note: MaxPool1d uses max_pool2d as it converts to 2D first.
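+# For reference, the positional indices used below follow the ATen schema:
+#   aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[],
+#                    int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor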
+@dynamo_tensorrt_converter(torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
+def aten_ops_max_pool(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.max_poolNd(
+        network,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        kernel_size=args[1],
+        stride=args_bounds_check(args, 2, replacement=[]),
+        padding=args_bounds_check(args, 3, replacement=0),
+        dilation=args_bounds_check(args, 4, replacement=1),
+        ceil_mode=args_bounds_check(args, 5, replacement=False),
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index e615599eb4..7a49222fba 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -11,6 +11,7 @@
     matmul,
     normalization,
     permutation,
+    pool,
     reduce,
     select,
     shape,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py
new file mode 100644
index 0000000000..a84402ba89
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -0,0 +1,109 @@
+from typing import Optional, Sequence, Union
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.fx.converters.converter_utils import (
+    has_dynamic_shape,
+    set_layer_name,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def avg_poolNd(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    kernel_size: Sequence[int],
+    stride: Union[int, Sequence[int]],
+    padding: Union[int, Sequence[int]] = 0,
+    ceil_mode: bool = False,
+    count_include_pad: bool = True,
+    divisor_override: Optional[int] = None,
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."
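+
+    # ceil_mode and divisor_override are already rejected by the capability
+    # validators in aten_ops_converters.py; the checks below guard any callers
+    # that reach this impl directly.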
+
+    if ceil_mode is not False:
+        raise RuntimeError("ceil_mode is not yet supported!")
+
+    if divisor_override is not None:
+        raise RuntimeError("divisor_override is not yet supported!")
+
+    dim = len(kernel_size)
+
+    kernel_size = extend_attr_to_tuple(kernel_size, dim)
+
+    if stride == []:
+        stride = kernel_size
+    else:
+        stride = extend_attr_to_tuple(stride, dim)
+
+    padding = extend_attr_to_tuple(padding, dim)
+
+    # add average pooling layer
+    pool_layer = network.add_pooling_nd(
+        input=input,
+        type=trt.PoolingType.AVERAGE,
+        window_size=kernel_size,
+    )
+
+    pool_layer.stride_nd = stride
+    pool_layer.padding_nd = padding
+    pool_layer.average_count_excludes_padding = not count_include_pad
+
+    set_layer_name(pool_layer, target, name, source_ir)
+    return pool_layer.get_output(0)
+
+
+def max_poolNd(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    kernel_size: Sequence[int],
+    stride: Union[int, Sequence[int]],
+    padding: Union[int, Sequence[int]] = 0,
+    dilation: Union[int, Sequence[int]] = 1,
+    ceil_mode: bool = False,
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."
+
+    if dilation != 1:
+        raise RuntimeError("dilation is not yet supported!")
+
+    if ceil_mode is not False:
+        raise RuntimeError("ceil_mode is not yet supported!")
+
+    dim = len(kernel_size)
+
+    kernel_size = extend_attr_to_tuple(kernel_size, dim)
+
+    if stride == []:
+        stride = kernel_size
+    else:
+        stride = extend_attr_to_tuple(stride, dim)
+
+    padding = extend_attr_to_tuple(padding, dim)
+
+    # add max pooling layer
+    pool_layer = network.add_pooling_nd(
+        input=input,
+        type=trt.PoolingType.MAX,
+        window_size=kernel_size,
+    )
+
+    pool_layer.stride_nd = stride
+    pool_layer.padding_nd = padding
+
+    set_layer_name(pool_layer, target, name, source_ir)
+    return pool_layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_pool_aten.py b/tests/py/dynamo/conversion/test_pool_aten.py
new file mode 100644
index 0000000000..4bd6e8ba25
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_pool_aten.py
@@ -0,0 +1,237 @@
+import torch
+from parameterized import param, parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestPoolConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (3, 1, 0),
+            (3, 1, 1),
+            (2, None, 0),
+            (4, 1, 1),
+            (5, 2, 0),
+            (7, 2, 1),
+        ]
+    )
+    def test_avg_pool1d(
+        self,
+        kernel_size,
+        stride=1,
+        padding=0,
+        ceil_mode=False,
+        count_include_pad=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.AvgPool1d(
+                    kernel_size, stride, padding, ceil_mode, count_include_pad
+                )
+
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={torch.ops.aten.avg_pool2d.default},
+        )
+
+    @parameterized.expand(
+        [
+            (3, 1, 0),
+            (3, 1, 1),
+            ((2, 2), None, (1, 0)),
+            ((4, 3), (1, 1), (1, 1)),
+            ((5, 4), (2, 1), (1, 0)),
+            ((7, 7), (1, 2), (0, 1)),
+        ]
+    )
+    def test_avg_pool2d(
+        self,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode=False,
+        count_include_pad=True,
+        divisor_override=None,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.AvgPool2d(
+                    kernel_size,
+                    stride,
+                    padding,
+                    ceil_mode,
+                    count_include_pad,
+                    divisor_override,
+                )
+
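+            # Note: a stride of None exercises the PyTorch default (stride =
+            # kernel_size), which reaches the ATen op as an empty stride list.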
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 32)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.avg_pool2d.default}
+        )
+
+    @parameterized.expand(
+        [
+            (3, 1, 0),
+            (3, 1, 1),
+            ((2, 2, 3), None, (1, 0, 1)),
+            ((4, 3, 2), (1, 1, 1), (1, 1, 0)),
+            ((5, 4, 3), (2, 1, 2), (1, 0, 1)),
+            ((7, 7, 7), (1, 2, 1), (0, 1, 1)),
+        ]
+    )
+    def test_avg_pool3d(
+        self,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode=False,
+        count_include_pad=True,
+        divisor_override=None,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.AvgPool3d(
+                    kernel_size,
+                    stride,
+                    padding,
+                    ceil_mode,
+                    count_include_pad,
+                    divisor_override,
+                )
+
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.avg_pool3d.default}
+        )
+
+    @parameterized.expand(
+        [
+            (3, 1, 0),
+            (3, 1, 1),
+            (2, None, 0),
+            (4, 1, 1),
+            (5, 2, 0),
+            (7, 2, 1),
+        ]
+    )
+    def test_max_pool1d(
+        self,
+        kernel_size,
+        stride,
+        padding,
+        dilation=1,
+        return_indices=False,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.MaxPool1d(
+                    kernel_size, stride, padding, dilation, return_indices, ceil_mode
+                )
+
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={torch.ops.aten.max_pool2d.default},
+        )
+
+    @parameterized.expand(
+        [
+            (3, 1, 0),
+            (3, 1, 1),
+            ((2, 2), None, (1, 0)),
+            ((4, 3), (1, 1), (1, 1)),
+            ((5, 4), (2, 1), (1, 0)),
+            ((7, 7), (1, 2), (0, 1)),
+        ]
+    )
+    def test_max_pool2d(
+        self,
+        kernel_size,
+        stride,
+        padding,
+        dilation=1,
+        return_indices=False,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.MaxPool2d(
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                    return_indices,
+                    ceil_mode,
+                )
+
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 32)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool2d.default})
+
+    @parameterized.expand(
+        [
+            (3, 1, 0),
+            (3, 1, 1),
+            ((2, 2, 3), None, (1, 0, 1)),
+            ((4, 3, 2), (1, 1, 1), (1, 1, 0)),
+            ((5, 4, 3), (2, 1, 2), (1, 0, 1)),
+            ((7, 7, 7), (1, 2, 1), (0, 1, 1)),
+        ]
+    )
+    def test_max_pool3d(
+        self,
+        kernel_size,
+        stride,
+        padding,
+        dilation=1,
+        return_indices=False,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.MaxPool3d(
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                    return_indices,
+                    ceil_mode,
+                )
+
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool3d.default})
+
+
+if __name__ == "__main__":
+    run_tests()