
Commit 31cd039

fix/feat: Add and repair multiple converters
- Focus on SD-performance-accelerating converters
- Add test cases for converters to avoid regressions
1 parent 0d402fb commit 31cd039

7 files changed: +179 -26 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+32
@@ -1156,6 +1156,7 @@ def aten_ops_sub(
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.div.default)  # type: ignore[misc]
 def aten_ops_div(
     network: TRTNetwork,
     target: Target,
@@ -1413,3 +1414,34 @@ def aten_ops_linear(
         weight=args[1],
         bias=args_bounds_check(args, 2, None),
     )
+
+
+# TODO: expand the scope of this converter with aten.expand implementation
+def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
+    # The current implementation of broadcast_in_dim can only handle unsqueeze
+    return all(
+        broadcast_node.args[1][i] == 1
+        for i in range(len(broadcast_node.args[1]))
+        if i not in broadcast_node.args[2]
+    )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
+)  # type: ignore[misc]
+def aten_ops_broadcast_in_dim(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unsqueeze.broadcast_in_dim(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+    )
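The capability validator above admits only broadcast_in_dim calls that reduce to unsqueezes. A standalone sketch of that rule (the helper name is hypothetical, used here only for illustration), exercised with the shapes from the new tests:

# Sketch of the validator's rule: every output dimension NOT listed in
# broadcast_dimensions must be a singleton (size 1), so the broadcast is
# expressible as a sequence of unsqueezes.
def is_unsqueeze_broadcast(shape, broadcast_dimensions):
    return all(
        shape[i] == 1 for i in range(len(shape)) if i not in broadcast_dimensions
    )

assert is_unsqueeze_broadcast([4, 5, 6, 1, 1], [0, 1, 2])      # accepted
assert not is_unsqueeze_broadcast([4, 5, 6, 7, 1], [0, 1, 2])  # rejected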

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

+23 -24

@@ -1,5 +1,6 @@
 from typing import Optional
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -22,16 +23,6 @@ def where(
     other: TRTTensor,
     condition: TRTTensor,
 ) -> TRTTensor:
-    input_dim = len(tuple(input.shape))
-    other_dim = len(tuple(other.shape))
-    condition_dim = len(tuple(condition.shape))
-
-    if type(input) != TRTTensor:
-        assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"
-
-    if type(other) != TRTTensor:
-        assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"
-
     if not (broadcastable(input, other)):
         assert "The two torch tensors should be broadcastable"
 
@@ -48,33 +39,37 @@ def where(
     x_shape = list(input.shape)
     y_shape = list(other.shape)
     condition_shape = list(condition.shape)
+
     output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
 
     # expand shape
-    if type(condition) != TRTTensor:
-        assert condition.dtype == torch.bool, "condition dtype is not bool"
+    if not isinstance(condition, TRTTensor):
+        assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
         if condition_shape != output_shape:
-            condition.expand(output_shape)
-        condition = condition.to(torch.int32)
-        condition_const = get_trt_tensor(network, condition, f"{name}_condition")
-        condition_layer = network.add_identity(condition_const)
-        condition_layer.set_output_type(0, trt.bool)
-        set_layer_name(condition_layer, target, f"{name}_condition")
-        condition_val = condition_layer.get_output(0)
+            condition = (
+                condition.expand(output_shape)
+                if isinstance(condition, torch.Tensor)
+                else np.broadcast_to(condition, output_shape)
+            )
+        condition_val = get_trt_tensor(network, condition, f"{name}_condition")
     else:
         assert condition.dtype == trt.bool, "mask dtype is not bool!"
-        if len(condition_shape) != condition_dim:
+        if condition_shape != output_shape:
             condition_val = expand(
                 network, target, source_ir, f"{name}_expand", condition, output_shape
             )
         else:
             condition_val = condition
 
-    if type(input) != TRTTensor:
+    if not isinstance(input, TRTTensor):
         if x_shape != output_shape:
             # special case where 1 element in input
             if len(input.shape) == 0:
-                input = input.unsqueeze(0)
+                input = (
+                    input.unsqueeze(0)
+                    if isinstance(input, torch.Tensor)
+                    else np.expand_dims(input, axis=0)
+                )
             input = input.expand(output_shape)
         x_val = get_trt_tensor(network, input, f"{name}_x")
     else:
@@ -84,11 +79,15 @@ def where(
             network, target, source_ir, f"{name}_x_expand", input, output_shape
         )
 
-    if type(other) != TRTTensor:
+    if not isinstance(other, TRTTensor):
        if y_shape != output_shape:
            # special case where 1 element in other
            if len(other.shape) == 0:
-                other = other.unsqueeze(0)
+                other = (
+                    other.unsqueeze(0)
+                    if isinstance(other, torch.Tensor)
+                    else np.expand_dims(other, axis=0)
+                )
            other = other.expand(output_shape)
        y_val = get_trt_tensor(network, other, f"{name}_y")
    else:
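For reference, a minimal sketch of the new 0-d handling for constant inputs, assuming a hypothetical helper name lift_zero_dim: torch.Tensor and np.ndarray need different APIs for the same lift-then-broadcast step, which is why each branch above checks the type before expanding.

import numpy as np
import torch

# A 0-d constant is lifted to 1-d with the API matching its type before
# being broadcast to the output shape, mirroring the new branches above.
def lift_zero_dim(value):
    if len(value.shape) == 0:
        return (
            value.unsqueeze(0)
            if isinstance(value, torch.Tensor)
            else np.expand_dims(value, axis=0)
        )
    return value

assert lift_zero_dim(torch.tensor(1.0)).shape == (1,)
assert lift_zero_dim(np.array(1.0)).shape == (1,)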

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

+1 -1

@@ -49,7 +49,7 @@ def sum(
     ):
         input_val = cast_trt_tensor(network, input_val, trt.float32, name)
 
-    if dim is None:
+    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
         dim = tuple(range(len(input_val.shape)))
     layer = network.add_reduce(
         input_val,
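The one-line change treats an empty dim list or tuple the same as dim=None. A small sketch of that normalization in isolation (the helper name is hypothetical):

# An empty dim tuple/list now reduces over every axis, exactly like dim=None.
def normalize_reduce_dims(dim, rank):
    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
        return tuple(range(rank))
    return tuple(dim) if isinstance(dim, (tuple, list)) else (dim,)

assert normalize_reduce_dims(None, 3) == (0, 1, 2)
assert normalize_reduce_dims((), 3) == (0, 1, 2)
assert normalize_reduce_dims((1,), 3) == (1,)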

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

+40 -1

@@ -1,4 +1,4 @@
-from typing import Optional, cast
+from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -48,3 +48,42 @@ def unsqueeze(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def broadcast_in_dim(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_t: TRTTensor,
+    shape: Sequence[int],
+    broadcast_dimensions: Sequence[int],
+) -> TRTTensor:
+    augmented_shape_list: List[Optional[int]] = list(shape)
+
+    # For each dimension being broadcasted, set the augmented shape to None
+    for broadcast_dim in broadcast_dimensions:
+        augmented_shape_list[broadcast_dim] = None
+
+    # TODO: Expand support to arbitrary broadcasts
+    assert all(
+        dim in (1, None) for dim in augmented_shape_list
+    ), "broadcast_in_dim currently only supports unsqueeze broadcasting"
+
+    # Unsqueeze the shape repeatedly to broadcast
+    output = input_t
+    for idx, x in enumerate(augmented_shape_list):
+        # If the value is not None, that dimension is to be broadcasted
+        if x is not None:
+            output = unsqueeze(
+                network,
+                target,
+                source_ir,
+                name + f"_unsqueeze_for_broadcast_{idx}",
+                output,
+                idx,
+            )
+
+    assert tuple(output.shape) == tuple(shape), "broadcast_in_dim shapes don't match"
+
+    return output
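To see the augmented-shape bookkeeping in isolation: positions named in broadcast_dimensions carry the input data through (masked to None here), and every remaining position, which the assert guarantees is 1, becomes an unsqueeze axis. A plain-Python sketch using the shape from the supported test:

# Positions in broadcast_dimensions carry input data (marked None here);
# all other positions must be singletons and become unsqueeze axes.
shape = [4, 5, 6, 1, 1]
broadcast_dimensions = [0, 1, 2]

augmented = list(shape)
for d in broadcast_dimensions:
    augmented[d] = None

unsqueeze_axes = [i for i, v in enumerate(augmented) if v is not None]
assert unsqueeze_axes == [3, 4]  # a (4, 5, 6) input is unsqueezed twice at the tail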

tests/py/dynamo/conversion/test_div_aten.py

+18
@@ -82,6 +82,24 @@ def forward(self, lhs_val):
             expected_ops={torch.ops.aten.div.Tensor_mode},
         )
 
+    @parameterized.expand(
+        [
+            ("2d", (2, 1)),
+            ("3d", (2, 1, 2)),
+        ]
+    )
+    def test_prims_div_tensor(self, _, shape):
+        class div(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.prims.div.default(lhs_val, rhs_val)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            div(),
+            inputs,
+            expected_ops={torch.ops.prims.div.default},
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_unsqueeze_aten.py

+47
@@ -4,6 +4,7 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
+from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
 
 from .harness import DispatchTestCase
 
@@ -59,5 +60,51 @@ def forward(self, x):
         )
 
 
+class TestBroadcastInDim(DispatchTestCase):
+    def test_broadcast_in_dim_supported(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(
+                    x, [4, 5, 6, 1, 1], [0, 1, 2]
+                )
+
+        inputs = [torch.randn(4, 5, 6)]
+        self.run_test(
+            Unsqueeze(), inputs, expected_ops={torch.ops.prims.broadcast_in_dim.default}
+        )
+
+    def test_broadcast_in_dim_supported_singleton(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(x, [1, 1, 1], [0, 1])
+
+        inputs = [torch.randn(1, 1)]
+        self.run_test(
+            Unsqueeze(), inputs, expected_ops={torch.ops.prims.broadcast_in_dim.default}
+        )
+
+    # TODO: Remove this test when support is updated
+    def test_broadcast_in_dim_unsupported(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(
+                    x, [4, 5, 6, 7, 1], [0, 1, 2]
+                )
+
+        inputs = [torch.randn(4, 5, 6)]
+        with self.assertRaises(UnsupportedOperatorException):
+            self.run_test(
+                Unsqueeze(),
+                inputs,
+                expected_ops={torch.ops.prims.broadcast_in_dim.default},
+            )
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_where_aten.py

+18
@@ -44,6 +44,24 @@ def forward(self, condition, x, y):
             expected_ops={torch.ops.aten.where.self},
         )
 
+    def test_const_input(self):
+        class Where(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.inputY = torch.randn((5, 6, 7))
+                self.inputX = torch.randn((5, 6, 7))
+
+            def forward(self, condition):
+                return torch.where(condition, self.inputX, self.inputY)
+
+        input1 = torch.randn((5, 6, 7))
+        condition = input1 < 0
+        self.run_test(
+            Where(),
+            (condition,),
+            expected_ops={torch.ops.aten.where.self},
+        )
+
 
 if __name__ == "__main__":
     run_tests()
