Skip to content

Commit df401dd

Browse files
borisfomgs-olive
authored and committed
Moved clamp to impl
Signed-off-by: Boris Fomitchev <[email protected]>

fixed method name

Signed-off-by: Boris Fomitchev <[email protected]>
1 parent b424735 commit df401dd

File tree

4 files changed

+180
-56
lines changed

4 files changed

+180
-56
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 10 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torch_tensorrt.fx.converters.impl.normalization import batch_norm
3333
from torch_tensorrt.fx.converters.impl.normalization import layer_norm
3434
from torch_tensorrt.fx.converters.impl.normalization import softmax
35+
from torch_tensorrt.fx.converters.impl.elementwise import clamp
3536
from torch_tensorrt.fx.converters.impl.unary import sign
3637
from torch_tensorrt.fx.converters.impl.elementwise.base import (
3738
convert_binary_elementwise,
@@ -2818,38 +2819,6 @@ def acc_ops_linear(
28182819
return res
28192820

28202821

2821-
def add_clamp(network, input, val, op, name):
2822-
if not len(input.shape):
2823-
# clamping scalar
2824-
acc_ops_clamp_trt = get_trt_tensor(
2825-
network,
2826-
squeeze_left(
2827-
torch.tensor(
2828-
[val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
2829-
)
2830-
),
2831-
f"{name}_clamp_{val}",
2832-
)
2833-
else:
2834-
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
2835-
acc_ops_clamp_tensor = (
2836-
(
2837-
val
2838-
* torch.ones(
2839-
acc_ops_clamp_shape,
2840-
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
2841-
)
2842-
)
2843-
.cpu()
2844-
.numpy()
2845-
)
2846-
acc_ops_clamp_trt = network.add_constant(
2847-
acc_ops_clamp_shape, acc_ops_clamp_tensor
2848-
).get_output(0)
2849-
layer = network.add_elementwise(input, acc_ops_clamp_trt, op)
2850-
return layer
2851-
2852-
28532822
@tensorrt_converter(acc_ops.clamp)
def acc_ops_clamp(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert acc_ops.clamp to TensorRT layers.

    Thin wrapper that delegates to the shared implementation in
    ``impl.elementwise.clamp``; the acc tracer passes the input and both
    bounds as keyword arguments, where a ``None`` bound means "unbounded
    on that side".
    """
    # Delegate to the shared impl so acc and aten converters stay in sync.
    return clamp.clamp(
        network,
        target,
        SourceIR.ACC,
        name,
        input_val=kwargs["input"],
        min_val=kwargs["min"],
        max_val=kwargs["max"],
    )
28852839

28862840

28872841
@tensorrt_converter(acc_ops.tuple_construct)

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,15 @@
3434
from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
3535
from torch_tensorrt.fx.converters.impl.condition import where
3636
from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
37+
from torch_tensorrt.fx.converters.impl.elementwise import clamp
3738

3839
_LOGGER: logging.Logger = logging.getLogger(__name__)
3940

41+
42+
def or_none(args, i):
    """Return ``args[i]`` when that position exists, otherwise ``None``.

    Used to read optional trailing positional arguments of aten ops.
    """
    return args[i] if i < len(args) else None
44+
45+
4046
## converter list in alphabetic order
4147
@tensorrt_converter(torch.ops.aten.add.Tensor)
4248
def aten_ops_add(
@@ -610,6 +616,25 @@ def aten_ops_cat(
610616
return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
611617

612618

619+
@tensorrt_converter(torch.ops.aten.clamp.default)
def aten_ops_clamp(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten.clamp.default to TensorRT layers.

    aten.clamp passes its operands positionally: ``args = (input, min?, max?)``.
    Missing trailing bounds mean "unbounded on that side", hence ``or_none``.
    """
    # Fix: tag layers with SourceIR.ATEN — the original passed SourceIR.ACC,
    # a copy/paste from the acc converter, which mislabels layer provenance
    # for layers created from the aten path.
    return clamp.clamp(
        network,
        target,
        SourceIR.ATEN,
        name,
        input_val=args[0],
        min_val=or_none(args, 1),
        max_val=or_none(args, 2),
    )
636+
637+
613638
@tensorrt_converter(torch.ops.aten.expand.default)
614639
def aten_ops_expand(
615640
network: TRTNetwork,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
import operator
3+
import warnings
4+
from typing import cast, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
5+
6+
# @manual=//deeplearning/trt/python:py_tensorrt
7+
import tensorrt as trt
8+
import torch
9+
from torch.fx.node import Argument, Target
10+
11+
from ...converter_utils import * # noqa: F403
12+
from ....utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
13+
14+
from torch_tensorrt.fx.types import (
15+
TRTNetwork,
16+
TRTTensor,
17+
)
18+
19+
20+
def add_clamp(network, input, val, op, name):
    """Add an elementwise `op` layer clamping `input` against scalar `val`.

    Materializes `val` as a constant TRT tensor broadcastable against
    `input`, then adds an elementwise layer (expected: MAX for a lower
    bound, MIN for an upper bound) combining the two.

    Returns the new elementwise layer (not its output tensor).
    """
    if not len(input.shape):
        # clamping scalar — a 0-d input needs a 0-d constant; squeeze_left
        # strips the leading dim of the 1-element tensor we build here.
        acc_ops_clamp_trt = get_trt_tensor(
            network,
            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
            f"{name}_clamp_{val}",
        )
    else:
        acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
        # Constant filled with `val`, matching the input's dtype so the
        # elementwise layer sees consistent types.
        acc_ops_clamp_tensor = (
            (
                val
                * torch.ones(
                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
                )
            )
            .cpu()
            .numpy()
        )
        acc_ops_clamp_trt = network.add_constant(
            acc_ops_clamp_shape, acc_ops_clamp_tensor
        ).get_output(0)
    layer = network.add_elementwise(input, acc_ops_clamp_trt, op)
    return layer
45+
46+
47+
def clamp(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val,
    min_val=None,
    max_val=None,
) -> TRTTensor:
    """Clamp ``input_val`` elementwise into ``[min_val, max_val]``.

    Each bound that is not ``None`` is applied via an elementwise layer
    (MAX for the lower bound, MIN for the upper one); a ``None`` bound is
    skipped. Returns the output tensor of the last clamping layer, or the
    untouched input when both bounds are ``None``.
    """
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"Clamp received input {input_val} that is not part "
            "of the TensorRT region!"
        )

    result = input_val
    # Lower bound first, then upper bound, each chained onto the previous
    # result — identical ordering to applying min then max sequentially.
    for bound, elementwise_op, suffix in (
        (min_val, trt.ElementWiseOperation.MAX, "min"),
        (max_val, trt.ElementWiseOperation.MIN, "max"),
    ):
        if bound is None:
            continue
        bound_layer = add_clamp(network, result, bound, elementwise_op, name)
        set_layer_name(bound_layer, target, f"{name}_clamp_{suffix}")
        result = bound_layer.get_output(0)

    return result
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
from parameterized import param, parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
5+
6+
7+
class TestClampConverter(DispatchTestCase):
    """Dispatch tests for the aten.clamp.default converter."""

    # NOTE: the `min`/`max` parameter names shadow the builtins, but they
    # must match the keywords used in param(...) below, so they stay.
    @parameterized.expand(
        [
            param("default", min=-1, max=0),
            param("min", min=0.5),
            param("max", max=0.5),
            # min > max: torch.clamp applies min then max, yielding max.
            param("minBiggerThanMax", min=1, max=0),
            # most-negative finite float32 as the lower bound.
            param("float32Boundary", min=-3.4028234663852886e38),
        ]
    )
    def test_clamp(
        self,
        test_name,
        min=None,
        max=None,
    ):
        """Static-shape clamp with various bound combinations."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.clamp(x, min, max)

        inputs = [torch.randn(3, 4)]
        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.clamp.default})

    @parameterized.expand(
        [
            param("default", min=-1, max=0),
            param("min", min=0.5),
            param("max", max=0.5),
            param("minBiggerThanMax", min=1, max=0),
        ]
    )
    def test_clamp_with_dynamic_shape_four_dimensions(
        self,
        test_name,
        min=None,
        max=None,
    ):
        """Dynamic-shape clamp, for both tensor and 0-d (scalar) inputs."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.clamp(x, min, max)

        # torch.mean reduces to a 0-d tensor — exercises the scalar
        # branch of the converter implementation.
        class TestScalarModule(torch.nn.Module):
            def forward(self, x):
                y = torch.mean(x)
                return torch.clamp(y, min, max)

        # First two dims dynamic between 1 and 5; last two fixed at 3.
        input_specs = [
            InputTensorSpec(
                shape=(-1, -1, 3, 3),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
            ),
        ]

        self.run_test_with_dynamic_shape(
            TestModule(), input_specs, expected_ops={torch.ops.aten.clamp.default}
        )
        self.run_test_with_dynamic_shape(
            TestScalarModule(), input_specs, expected_ops={torch.ops.aten.clamp.default}
        )
67+
68+
69+
if __name__ == "__main__":
    # Allow running this test file directly via PyTorch's test runner.
    run_tests()

0 commit comments

Comments
 (0)