Skip to content

feat: support aten.clamp.Tensor and update aten.clamp.default dynamo converters #2522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 3 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,25 +487,6 @@ def aten_ops_softplus(
)


@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
def aten_ops_clip(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.activation.clip(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
alpha=args_bounds_check(args, 1),
beta=args_bounds_check(args, 2),
)


@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
def aten_ops_hard_sigmoid(
ctx: ConversionContext,
Expand Down Expand Up @@ -683,6 +664,9 @@ def aten_ops_where(


@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
def aten_ops_clamp(
ctx: ConversionContext,
target: Target,
Expand Down
30 changes: 0 additions & 30 deletions py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,36 +235,6 @@ def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]
)


def clip(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
alpha: float,
beta: float,
) -> TRTTensor:
operation_type = trt.ActivationType.CLIP

def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]:
def clip_fn(x: float) -> float:
return max(alpha, min(beta, x))

return clip_fn(dyn_range[0]), clip_fn(dyn_range[1])

return convert_activation(
ctx,
target,
source_ir,
name,
operation_type,
input_val,
alpha=alpha,
beta=beta,
dyn_range_fn=clip_dyn_range_fn,
)


def hard_sigmoid(
ctx: ConversionContext,
target: Target,
Expand Down
62 changes: 9 additions & 53 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
Expand All @@ -17,7 +16,6 @@
)
from torch_tensorrt.dynamo.conversion.impl.unary import sign
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

Expand Down Expand Up @@ -186,63 +184,21 @@ def clamp(
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
min_val: Optional[float] = None,
max_val: Optional[float] = None,
min_val: Optional[Union[int, float, TRTTensor]] = None,
max_val: Optional[Union[int, float, TRTTensor]] = None,
) -> TRTTensor:
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"Clamp received input {input_val} that is not part "
"of the TensorRT region!"
)

def _add_layer(
ctx: ConversionContext,
input: TRTTensor,
val: float,
op: trt.ElementWiseOperation,
name: str,
) -> (
trt.ILayer
): # TODO: Simplify and merge implementations, should just be max and min stacked
if not len(input.shape):
# clamping scalar
acc_ops_clamp_trt = get_trt_tensor(
ctx,
squeeze_left(
np.array(
[val],
dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
)
),
f"{name}_clamp_{val}",
)
else:
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
acc_ops_clamp_tensor = np.full(
acc_ops_clamp_shape,
val,
dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
)
acc_ops_clamp_trt = ctx.net.add_constant(
acc_ops_clamp_shape, acc_ops_clamp_tensor
).get_output(0)
layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
return layer

clamped_val = input_val
if min_val is not None:
clamp_min_layer = _add_layer(
ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
clamped_val = impl.elementwise.max(
ctx, target, source_ir, f"{name}_max", clamped_val, min_val
)
set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
input_val = clamp_min_layer.get_output(0)

if max_val is not None:
clamp_max_layer = _add_layer(
ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
clamped_val = impl.elementwise.min(
ctx, target, source_ir, f"{name}_min", clamped_val, max_val
)
set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
input_val = clamp_max_layer.get_output(0)

return input_val
return clamped_val


def add(
Expand Down
26 changes: 25 additions & 1 deletion tests/py/dynamo/conversion/test_clamp_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def forward(self, x):

class TestScalarModule(torch.nn.Module):
def forward(self, x):
y = torch.ops.aten.mean.default(x)
y = torch.ops.aten.mean.dim(x, None, True)
return torch.ops.aten.clamp.default(y, min, max)

input_specs = [
Expand All @@ -63,6 +63,30 @@ def forward(self, x):
self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)

@parameterized.expand(
[
param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)),
param("min", min=0.5 * torch.randn(3, 4)),
param("max", max=0.5 * torch.randn(3, 4)),
param(
"minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)
),
param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)),
]
)
def test_clamp_tensor(
self,
test_name,
min=None,
max=None,
):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.clamp.Tensor(x, min, max)

inputs = [torch.randn(3, 4)]
self.run_test(TestModule(), inputs)


if __name__ == "__main__":
run_tests()
35 changes: 31 additions & 4 deletions tests/py/dynamo/conversion/test_clip_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,38 @@ class TestClipConverter(DispatchTestCase):
def test_clip(self, test_name, min=None, max=None):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.clamp.default(x, min, max)
return torch.ops.aten.clip.default(x, min, max)

inputs = [torch.randn(3, 4)]
self.run_test(TestModule(), inputs)

@parameterized.expand(
[
param(
"defaultInt32",
min=torch.tensor(-1, dtype=torch.int32),
max=torch.tensor(0, dtype=torch.int32),
),
param(
"defaultFloat32",
min=torch.tensor(0.5, dtype=torch.float32),
max=torch.tensor(1.0, dtype=torch.float32),
),
param(
"minBiggerThanMax",
min=torch.tensor(1.0, dtype=torch.float32),
max=torch.tensor(0, dtype=torch.int32),
),
]
)
def test_clip(self, test_name, min=None, max=None):
class TestModule(torch.nn.Module):
def forward(self, x, min, max):
return torch.ops.aten.clip.Tensor(x, min, max)

inputs = [torch.randn(3, 4), min, max]
self.run_test(TestModule(), inputs)

@parameterized.expand(
[
param("default", min=-1, max=0),
Expand All @@ -37,12 +64,12 @@ def test_clip_with_dynamic_shape_four_dimensions(
):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.clamp.default(x, min, max)
return torch.ops.aten.clip.default(x, min, max)

class TestScalarModule(torch.nn.Module):
def forward(self, x):
y = torch.ops.aten.mean.default(x)
return torch.ops.aten.clamp.default(y, min, max)
y = torch.ops.aten.mean.dim(x, None, True)
return torch.ops.aten.clip.default(y, min, max)

input_specs = [
Input(
Expand Down