diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 478cf98dea..7f0e3505b5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2184,6 +2184,24 @@ def aten_ops_avg_pool( ) +@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default) +def aten_ops_adaptive_avg_pool( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pool.adaptive_avg_pool1d( + ctx, + target, + source_ir=SourceIR.ATEN, + name=name, + input=args[0], + output_size=args[1], + ) + + def max_pool_param_validator(pool_node: Node) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 13c8645a90..8c16f59030 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -1,6 +1,8 @@ -from typing import Optional, Sequence, Union +import math +from typing import Dict, Optional, Sequence, Union import tensorrt as trt +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -104,3 +106,66 @@ def max_poolNd( set_layer_name(pool_layer, target, name, source_ir) return pool_layer.get_output(0) + + +def adaptive_avg_pool1d( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + output_size: Union[int, Sequence[int]], +) -> TRTTensor: + def start_index(idx: int, out_dim: int, in_dim: int) -> int: + """Calculate the start index of each pooling window""" + return math.floor((float(idx) * float(in_dim)) / out_dim) + + def end_index(idx: int, out_dim: int, in_dim: int) -> int: + """Calculate the end index of each pooling window""" + return math.ceil((float(idx + 1) * float(in_dim)) / out_dim) + + in_dim = input.shape[-1] + out_dim = output_size if isinstance(output_size, int) else output_size[0] + output_list = [] + + # store {index: slice} for reducing repeated slice ops + idx_slice_map: Dict[int, TRTTensor] = {} + # iterate over each output dimension + for i in range(out_dim): + # calculate the start and end index of each pooling window + start = start_index(i, out_dim, in_dim) + end = end_index(i, out_dim, in_dim) + + # slice the input tensor from start to end index, the result of which is the window waiting for pooling + slices = [] + for j in range(start, end): + if j in idx_slice_map: + slice = idx_slice_map[j] + else: + slice = impl.select.select( + ctx, target, source_ir, f"{name}_select_{j}", input, -1, j + ) + slice = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_{i}_{j}", + slice, + (*slice.shape, 1), + ) + idx_slice_map[j] = slice + + slices.append(slice) + + slices = impl.cat.cat( + ctx, target, source_ir, f"{name}_slices_cat_{i}", slices, dim=-1 + ) + # calculate the mean of the slices (average pooling output) and append to the output list + output_list.append( + impl.reduce.mean( + ctx, target, source_ir, f"{name}_sum_{i}", slices, dim=-1, keepdim=True + ) + ) + + output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1) + return output diff --git a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py index e19e1b6187..3d48409631 100644 --- a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py +++ b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py @@ -9,102 +9,77 @@ class TestAdaptiveAvgPoolConverter(DispatchTestCase): @parameterized.expand( [ - ((64, 64),), - ((128, 64),), - # (64,), This case has been there in previous code but it isn't a valid pytorch code. - ] - ) - def test_adaptive_avgpool( - self, - output_size, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d(output_size) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 256, 256)] - self.run_test( - TestModule(), - inputs, - use_dynamo_tracer=True, - ) - - def test_adaptive_avgpool_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - Input( - shape=(-1, -1, 256, 256), - dtype=torch.float32, - shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ( + (2, 3), + 2, + ), + ( + (2, 8), + 8, + ), + ( + (1, 2, 3), + 2, + ), + ( + (2, 2, 8), + 16, + ), + ( + (2, 3), + (1,), + ), + ( + (2, 3), + (2,), + ), + ( + (2, 8), + (4,), + ), + ( + (2, 8), + (16,), + ), + ( + (2, 3, 1), + (1,), + ), + ( + (2, 3, 2), + (2,), + ), + ( + (2, 3, 4), + (4,), + ), + ( + (2, 2, 32), + (31,), + ), + ( + (2, 2, 32), + (64,), ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, use_dynamo_tracer=True - ) - - @parameterized.expand( - [ - ((16, 16, 16),), - ((32, 16, 4),), - (32,), ] ) - def test_adaptive_avgpool3d( + def test_adaptive_avg_pool1d( self, + input_shape, output_size, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d(output_size) - def forward(self, x): - return self.pool(x) + return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size) - inputs = [torch.randn(1, 3, 32, 64, 64)] + inputs = [torch.randn(input_shape)] self.run_test( TestModule(), inputs, - use_dynamo_tracer=True, + # use_dynamo_tracer=True, + enable_passes=True, ) - def test_adaptive_avgpool3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - Input( - shape=(-1, -1, 32, 64, 64), - dtype=torch.float32, - shape_ranges=[ - ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) - ], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - use_dynamo_tracer=True, - ) - - # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." - if __name__ == "__main__": run_tests()