diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 9625d1aeae..ede6e5e6a9 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -955,18 +955,11 @@ def aten_ops_expand(
     )
 
 
-def amax_param_validator(amax_node: Node) -> bool:
-    if len(amax_node.args) < 2:
-        _LOGGER.debug(
-            f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args."
-        )
-        return False
-
-    return True
-
-
-@dynamo_tensorrt_converter(
-    torch.ops.aten.amax.default, capability_validator=amax_param_validator
+@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
 )
 def aten_ops_amax(
     ctx: ConversionContext,
@@ -986,6 +979,30 @@ def aten_ops_amax(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.amin.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_amin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.amin(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, replacement=[]),
+        args_bounds_check(args, 2, replacement=False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.sum.default)
 @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
 @dynamo_tensorrt_converter(torch.ops.prims.sum.default)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
index 2fcd57a7f6..04f5596581 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -19,7 +19,7 @@ def amax(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    dim: Union[int, Sequence[int]],
+    dim: Sequence[int] = [],
     keepdim: bool = False,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
@@ -27,7 +27,7 @@ def amax(
     ):
         input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
 
-    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
+    if isinstance(dim, (tuple, list)) and len(dim) == 0:
         dim = tuple(range(len(input_val.shape)))
 
     layer = ctx.net.add_reduce(
@@ -40,6 +40,33 @@ def amax(
     return layer.get_output(0)
 
 
+def amin(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Sequence[int] = [],
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+
+    if isinstance(dim, (tuple, list)) and len(dim) == 0:
+        dim = tuple(range(len(input_val.shape)))
+
+    layer = ctx.net.add_reduce(
+        input_val,
+        trt.ReduceOperation.MIN,
+        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
+        keep_dims=keepdim,
+    )
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
 def sum(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_amin_aten.py b/tests/py/dynamo/conversion/test_amin_aten.py
new file mode 100644
index 0000000000..03ae9b6113
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_amin_aten.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestAminConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True),
+            ((2, 3, 4, 5), 3, True),
+            ((2, 3, 4, 5), 2, False),
+            ((6, 7, 5, 4, 5), 4, False),
+            ((1, 5, 2, 1), -1, True),
+        ]
+    )
+    def test_amin_dim_int_default(self, input_shape, dim, keep_dims):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amin(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((1, 2, 4), [], True),
+            ((3, 2, 4), [1], True),
+            ((2, 1, 4, 5), [0, 3], True),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False),
+        ]
+    )
+    def test_amin_dim_tuple_default(self, input_shape, dim, keep_dims):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amin(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True, torch.int, 0, 5),
+            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
+            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+            ((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amin_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amin(),
+            inputs,
+            check_dtype=False,
+        )
+
+    @parameterized.expand(
+        [
+            ((1, 2, 4), [], True, torch.int, 0, 5),
+            ((3, 2, 4), [1], True, torch.int, 0, 5),
+            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+            ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amin_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amin(),
+            inputs,
+            check_dtype=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
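
For reference, a minimal PyTorch-only sketch (illustrative, not part of the patch, no TensorRT required) of the aten.amin semantics the new converter and tests rely on: multiple reduction dims, negative dim indices, and an empty dim list that reduces over every dimension, which is why impl.reduce.amin expands dim=[] to tuple(range(rank)).

import torch

x = torch.randn(2, 3, 4)

# Single dim with keepdim=True: (2, 3, 4) -> (2, 1, 4)
print(torch.ops.aten.amin.default(x, [1], True).shape)

# Several dims at once, including a negative index: (2, 3, 4) -> (3,)
print(torch.ops.aten.amin.default(x, [0, -1], False).shape)

# Empty dim list with keepdim=True reduces every dimension: (2, 3, 4) -> (1, 1, 1),
# mirroring the ((1, 2, 4), [], True) case in test_amin_dim_tuple_default.
print(torch.ops.aten.amin.default(x, [], True).shape)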