From de0407287f77462549a52a637261334ba76568bf Mon Sep 17 00:00:00 2001
From: Evan Li <zewenl@nvidia.com>
Date: Thu, 7 Dec 2023 17:07:14 -0800
Subject: [PATCH 1/3] feat: support aten.clamp.Tensor and update
 aten.clamp.default dynamo converters

---
 .../dynamo/conversion/aten_ops_converters.py  |  1 +
 .../dynamo/conversion/impl/elementwise/ops.py | 71 ++++---------------
 tests/py/dynamo/conversion/test_clamp_aten.py | 26 ++++++-
 3 files changed, 40 insertions(+), 58 deletions(-)

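Note (illustrative, not part of the applied diff): this patch reworks the clamp converter to accept scalar or TRTTensor bounds and to lower clamp as a composition of elementwise max/min instead of per-bound constant layers. A minimal plain-PyTorch sketch of the intended semantics, assuming broadcastable tensor bounds:

    import torch

    x = torch.randn(3, 4)
    lo = torch.full_like(x, -0.5)   # tensor lower bound, as in aten.clamp.Tensor
    hi = torch.full_like(x, 0.5)    # tensor upper bound
    # The converter lowers clamp to min(max(x, lo), hi); in this patch a missing
    # bound defaults to -inf/+inf (tightened in the follow-up patch).
    ref = torch.minimum(torch.maximum(x, lo), hi)
    assert torch.equal(ref, torch.clamp(x, lo, hi))
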
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 4d9547d3ed..486ab5689b 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -683,6 +683,7 @@ def aten_ops_where(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index 06e07eedb1..b30d9a5626 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -1,6 +1,5 @@
 from typing import Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -17,7 +16,6 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -186,63 +184,22 @@ def clamp(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    min_val: Optional[float] = None,
-    max_val: Optional[float] = None,
+    min_val: Optional[Union[int, float, TRTTensor]] = None,
+    max_val: Optional[Union[int, float, TRTTensor]] = None,
 ) -> TRTTensor:
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def _add_layer(
-        ctx: ConversionContext,
-        input: TRTTensor,
-        val: float,
-        op: trt.ElementWiseOperation,
-        name: str,
-    ) -> (
-        trt.ILayer
-    ):  # TODO: Simplify and merge implementations, should just be max and min stacked
-        if not len(input.shape):
-            # clamping scalar
-            acc_ops_clamp_trt = get_trt_tensor(
-                ctx,
-                squeeze_left(
-                    np.array(
-                        [val],
-                        dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-                    )
-                ),
-                f"{name}_clamp_{val}",
-            )
-        else:
-            acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-            acc_ops_clamp_tensor = np.full(
-                acc_ops_clamp_shape,
-                val,
-                dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-            )
-            acc_ops_clamp_trt = ctx.net.add_constant(
-                acc_ops_clamp_shape, acc_ops_clamp_tensor
-            ).get_output(0)
-        layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
-        return layer
-
-    if min_val is not None:
-        clamp_min_layer = _add_layer(
-            ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
-        )
-        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
-        input_val = clamp_min_layer.get_output(0)
-    if max_val is not None:
-        clamp_max_layer = _add_layer(
-            ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
-        )
-        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
-        input_val = clamp_max_layer.get_output(0)
+    if min_val is None:
+        min_val = float("-inf")
+    if max_val is None:
+        max_val = float("inf")
 
-    return input_val
+    return impl.elementwise.min(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_min",
+        impl.elementwise.max(ctx, target, source_ir, f"{name}_max", input_val, min_val),
+        max_val,
+    )
 
 
 def add(
diff --git a/tests/py/dynamo/conversion/test_clamp_aten.py b/tests/py/dynamo/conversion/test_clamp_aten.py
index fcee7bfa3c..0bad9ee350 100644
--- a/tests/py/dynamo/conversion/test_clamp_aten.py
+++ b/tests/py/dynamo/conversion/test_clamp_aten.py
@@ -49,7 +49,7 @@ def forward(self, x):
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
+                y = torch.ops.aten.mean.dim(x, None, True)
                 return torch.ops.aten.clamp.default(y, min, max)
 
         input_specs = [
@@ -63,6 +63,30 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
         self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)
 
+    @parameterized.expand(
+        [
+            param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)),
+            param("min", min=0.5 * torch.randn(3, 4)),
+            param("max", max=0.5 * torch.randn(3, 4)),
+            param(
+                "minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)
+            ),
+            param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)),
+        ]
+    )
+    def test_clamp_tensor(
+        self,
+        test_name,
+        min=None,
+        max=None,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.clamp.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4)]
+        self.run_test(TestModule(), inputs)
+
 
 if __name__ == "__main__":
     run_tests()

From c3bc1274cf8917bdf00b31799c0b6e228f9e176f Mon Sep 17 00:00:00 2001
From: Evan Li <zewenl@nvidia.com>
Date: Mon, 11 Dec 2023 15:44:36 -0800
Subject: [PATCH 2/3] optimize clamp conversion: skip elementwise layers when a
 bound is None

---
 .../dynamo/conversion/impl/elementwise/ops.py | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

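Note (illustrative): rather than substituting -inf/+inf for an absent bound and always emitting both elementwise layers, the converter now skips a layer entirely when the corresponding bound is None. A minimal plain-Python sketch of the control flow, mirroring the diff below (torch calls stand in for the TensorRT elementwise ops):

    import torch

    def clamp_sketch(x, min_val=None, max_val=None):
        out = x
        if min_val is not None:  # only emit the MAX layer when a lower bound exists
            out = torch.maximum(out, torch.as_tensor(min_val, dtype=x.dtype))
        if max_val is not None:  # only emit the MIN layer when an upper bound exists
            out = torch.minimum(out, torch.as_tensor(max_val, dtype=x.dtype))
        return out
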
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index b30d9a5626..a69fca944b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -187,19 +187,18 @@ def clamp(
     min_val: Optional[Union[int, float, TRTTensor]] = None,
     max_val: Optional[Union[int, float, TRTTensor]] = None,
 ) -> TRTTensor:
-    if min_val is None:
-        min_val = float("-inf")
-    if max_val is None:
-        max_val = float("inf")
+    clamped_val = input_val
+    if min_val is not None:
+        clamped_val = impl.elementwise.max(
+            ctx, target, source_ir, f"{name}_max", clamped_val, min_val
+        )
 
-    return impl.elementwise.min(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_min",
-        impl.elementwise.max(ctx, target, source_ir, f"{name}_max", input_val, min_val),
-        max_val,
-    )
+    if max_val is not None:
+        clamped_val = impl.elementwise.min(
+            ctx, target, source_ir, f"{name}_min", clamped_val, max_val
+        )
+
+    return clamped_val
 
 
 def add(

From 3d8c055e63bb554cfe7b7238846dd70e17b8b2ee Mon Sep 17 00:00:00 2001
From: Evan Li <zewenl@nvidia.com>
Date: Mon, 18 Dec 2023 14:45:57 -0800
Subject: [PATCH 3/3] fix: clip is just an alias of clamp, so reuse the clamp
 converter

---
 .../dynamo/conversion/aten_ops_converters.py  | 21 ++---------
 .../dynamo/conversion/impl/activation/ops.py  | 30 ----------------
 tests/py/dynamo/conversion/test_clip_aten.py  | 35 ++++++++++++++++---
 3 files changed, 33 insertions(+), 53 deletions(-)

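Note (illustrative): aten.clip is an alias of aten.clamp, so the dedicated CLIP-activation converter is removed and clip.default/clip.Tensor are registered on the existing aten_ops_clamp converter. A quick plain-PyTorch check of the aliasing this relies on:

    import torch

    x = torch.randn(3, 4)
    assert torch.equal(
        torch.ops.aten.clip.default(x, -1.0, 1.0),
        torch.ops.aten.clamp.default(x, -1.0, 1.0),
    )
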
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 486ab5689b..e204152256 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -487,25 +487,6 @@ def aten_ops_softplus(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
-def aten_ops_clip(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.clip(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        alpha=args_bounds_check(args, 1),
-        beta=args_bounds_check(args, 2),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
 def aten_ops_hard_sigmoid(
     ctx: ConversionContext,
@@ -684,6 +665,8 @@ def aten_ops_where(
 
 @dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
 @dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
index ac77f790cb..f578351ef2 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -235,36 +235,6 @@ def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]
     )
 
 
-def clip(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input_val: TRTTensor,
-    alpha: float,
-    beta: float,
-) -> TRTTensor:
-    operation_type = trt.ActivationType.CLIP
-
-    def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]:
-        def clip_fn(x: float) -> float:
-            return max(alpha, min(beta, x))
-
-        return clip_fn(dyn_range[0]), clip_fn(dyn_range[1])
-
-    return convert_activation(
-        ctx,
-        target,
-        source_ir,
-        name,
-        operation_type,
-        input_val,
-        alpha=alpha,
-        beta=beta,
-        dyn_range_fn=clip_dyn_range_fn,
-    )
-
-
 def hard_sigmoid(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_clip_aten.py b/tests/py/dynamo/conversion/test_clip_aten.py
index a3819fb4dd..447e2c9e17 100644
--- a/tests/py/dynamo/conversion/test_clip_aten.py
+++ b/tests/py/dynamo/conversion/test_clip_aten.py
@@ -19,11 +19,38 @@ class TestClipConverter(DispatchTestCase):
     def test_clip(self, test_name, min=None, max=None):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.clamp.default(x, min, max)
+                return torch.ops.aten.clip.default(x, min, max)
 
         inputs = [torch.randn(3, 4)]
         self.run_test(TestModule(), inputs)
 
+    @parameterized.expand(
+        [
+            param(
+                "defaultInt32",
+                min=torch.tensor(-1, dtype=torch.int32),
+                max=torch.tensor(0, dtype=torch.int32),
+            ),
+            param(
+                "defaultFloat32",
+                min=torch.tensor(0.5, dtype=torch.float32),
+                max=torch.tensor(1.0, dtype=torch.float32),
+            ),
+            param(
+                "minBiggerThanMax",
+                min=torch.tensor(1.0, dtype=torch.float32),
+                max=torch.tensor(0, dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_clip_tensor(self, test_name, min=None, max=None):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, min, max):
+                return torch.ops.aten.clip.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4), min, max]
+        self.run_test(TestModule(), inputs)
+
     @parameterized.expand(
         [
             param("default", min=-1, max=0),
@@ -37,12 +64,12 @@ def test_clip_with_dynamic_shape_four_dimensions(
     ):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten.clamp.default(x, min, max)
+                return torch.ops.aten.clip.default(x, min, max)
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
-                return torch.ops.aten.clamp.default(y, min, max)
+                y = torch.ops.aten.mean.dim(x, None, True)
+                return torch.ops.aten.clip.default(y, min, max)
 
         input_specs = [
             Input(