Skip to content

Commit b4da15e

Browse files
apbosegs-olive
authored andcommitted
Converter reorg and where operator
adding where aten op aten::where correction and linting error changes
1 parent 7551eee commit b4da15e

File tree

6 files changed

+237
-1
lines changed

6 files changed

+237
-1
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+20
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torch_tensorrt.fx.converters.impl.select import select
3333
from torch_tensorrt.fx.converters.impl.slice import slice_op
3434
from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
35+
from torch_tensorrt.fx.converters.impl.condition import where
3536

3637
_LOGGER: logging.Logger = logging.getLogger(__name__)
3738

@@ -563,6 +564,25 @@ def aten_ops_tanh(
563564
)
564565

565566

567+
@tensorrt_converter(torch.ops.aten.where.self)
568+
def aten_ops_where(
569+
network: TRTNetwork,
570+
target: Target,
571+
args: Tuple[Argument, ...],
572+
kwargs: Dict[str, Argument],
573+
name: str,
574+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
575+
return where(
576+
network,
577+
target,
578+
SourceIR.ATEN,
579+
name,
580+
args[1],
581+
args[2],
582+
args[0],
583+
)
584+
585+
566586
@tensorrt_converter(torch.ops.aten.cat.default)
567587
def aten_ops_cat(
568588
network: TRTNetwork,

py/torch_tensorrt/fx/converters/converter_utils.py

+31
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,37 @@ def broadcast(
409409
return a, b
410410

411411

412+
def broadcastable(
413+
a: TRTTensor,
414+
b: TRTTensor,
415+
) -> bool:
416+
"Check if two tensors are broadcastable according to torch rules"
417+
a_shape = tuple(a.shape)
418+
b_shape = tuple(b.shape)
419+
# check from the trailing
420+
diff = len(a_shape) - len(b_shape)
421+
if diff == 0:
422+
return True
423+
if diff > 0:
424+
max = len(a_shape)
425+
min = len(b_shape)
426+
greater_tensor = a_shape
427+
lesser_tensor = b_shape
428+
elif diff < 0:
429+
max = len(b_shape)
430+
min = len(a_shape)
431+
greater_tensor = b_shape
432+
lesser_tensor = a_shape
433+
j = min - 1
434+
for i in range(max - 1, diff - 1, -1):
435+
if not (
436+
greater_tensor[i] != lesser_tensor[j]
437+
and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
438+
):
439+
return False
440+
return True
441+
442+
412443
def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
413444
"""
414445
Squeeze the size-1 dimensions on the left side of the shape tuple.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ops import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import operator
2+
import warnings
3+
from typing import Optional, cast
4+
5+
import numpy as np
6+
7+
import tensorrt as trt
8+
import torch
9+
from torch.fx.node import Target
10+
11+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
12+
from torch_tensorrt.fx.converters.converter_utils import (
13+
SourceIR,
14+
broadcast,
15+
broadcastable,
16+
get_trt_tensor,
17+
set_layer_name,
18+
)
19+
from torch_tensorrt.fx.converters.impl.slice import expand
20+
21+
22+
def where(
23+
network: TRTNetwork,
24+
target: Target,
25+
source_ir: Optional[SourceIR],
26+
name: str,
27+
input: TRTTensor,
28+
other: TRTTensor,
29+
condition: TRTTensor,
30+
) -> TRTTensor:
31+
input_dim = len(tuple(input.shape))
32+
other_dim = len(tuple(other.shape))
33+
condition_dim = len(tuple(condition.shape))
34+
35+
if type(input) != TRTTensor:
36+
assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"
37+
38+
if type(other) != TRTTensor:
39+
assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"
40+
41+
if not (broadcastable(input, other)):
42+
assert f"The two torch tensors should be broadcastable"
43+
44+
# get output shape
45+
# purpose of this is to bring input and other rank same as
46+
# output_shape to input it to the add_expand operation
47+
# condition will have dimension of either input or other
48+
input, other = broadcast(network, input, other, f"{name}_x", f"{name}_y")
49+
if len(tuple(condition.shape)) != len(tuple(input.shape)):
50+
condition, input = broadcast(
51+
network, condition, input, f"{name}_condition", f"{name}_x"
52+
)
53+
54+
x_shape = list(input.shape)
55+
y_shape = list(other.shape)
56+
condition_shape = list(condition.shape)
57+
output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
58+
59+
# expand shape
60+
if type(condition) != TRTTensor:
61+
assert condition.dtype == torch.bool, "condition dtype is not bool"
62+
if condition_shape != output_shape:
63+
condition.expand(output_shape)
64+
condition = condition.to(torch.int32)
65+
condition_const = get_trt_tensor(network, condition, f"{name}_condition")
66+
condition_layer = network.add_identity(condition_const)
67+
condition_layer.set_output_type(0, trt.bool)
68+
set_layer_name(condition_layer, target, f"{name}_condition")
69+
condition_val = condition_layer.get_output(0)
70+
else:
71+
assert condition.dtype == trt.bool, "mask dtype is not bool!"
72+
if condition_shape != condition_dim:
73+
condition_val = expand(
74+
network, target, source_ir, f"{name}_expand", condition, output_shape
75+
)
76+
else:
77+
condition_val = condition
78+
79+
if type(input) != TRTTensor:
80+
if x_shape != input_dim:
81+
# special case where 1 element in input
82+
if len(input.shape) == 0:
83+
input = input.unsqueeze(0)
84+
input = input.expand(output_shape)
85+
x_val = get_trt_tensor(network, input, f"{name}_x")
86+
else:
87+
x_val = input
88+
if x_shape != output_shape:
89+
x_val = expand(
90+
network, target, source_ir, f"{name}_x_expand", input, output_shape
91+
)
92+
93+
if type(other) != TRTTensor:
94+
if y_shape != output_shape:
95+
# special case where 1 element in other
96+
if len(other.shape) == 0:
97+
other = other.unsqueeze(0)
98+
other = other.expand(output_shape)
99+
y_val = get_trt_tensor(network, other, f"{name}_y")
100+
else:
101+
y_val = other
102+
if y_shape != other_dim:
103+
y_val = expand(
104+
network, target, source_ir, f"{name}_y_expand", y_val, output_shape
105+
)
106+
107+
select_layer = network.add_select(condition_val, x_val, y_val)
108+
109+
set_layer_name(select_layer, target, f"{name}_select")
110+
111+
return select_layer.get_output(0)

py/torch_tensorrt/fx/converters/impl/slice/ops.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
import torch
1111
from torch.fx.node import Target
1212

13-
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
13+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
14+
1415
from torch_tensorrt.fx.converters.converter_utils import (
1516
SourceIR,
1617
set_layer_name,
1718
get_positive_dim,
1819
has_dynamic_shape,
20+
broadcast,
21+
get_trt_tensor,
1922
)
2023
from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape
2124
from torch_tensorrt.fx.converters.impl.slice.base import slice
@@ -62,3 +65,40 @@ def slice_op(
6265
output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
6366

6467
return slice(network, target, source_ir, name, input, start, output_shape, stride)
68+
69+
70+
def expand(
71+
network: TRTNetwork,
72+
target: Target,
73+
source_ir: Optional[SourceIR],
74+
name: str,
75+
input: TRTTensor,
76+
sizes: Shape,
77+
) -> TRTTensor:
78+
shape = list(sizes)
79+
80+
input_val = get_trt_tensor(network, input, f"{name}_input")
81+
82+
if network.has_implicit_batch_dimension:
83+
shape = shape[1:]
84+
85+
ranks = len(input_val.shape)
86+
# TRT does not support different dimension size
87+
# though this condition is not seen in the case of bmm
88+
# where input_t and shape dimensions are not equal
89+
assert len(shape) >= ranks
90+
if len(shape) != ranks:
91+
shape_tuple = tuple([0] * len(shape))
92+
shape_tensor = get_trt_tensor(network, input, f"{name}_shape")
93+
input_val, shape_tensor = broadcast(
94+
network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
95+
)
96+
ranks = len(shape)
97+
98+
inshape = tuple(input_val.shape)
99+
shape = tuple(shape)
100+
start = tuple([0] * ranks)
101+
stride = tuple(
102+
[int(i == o) for i, o in zip(inshape, shape)]
103+
) # stride == 1 if dimensions match, 0 otherwise
104+
return slice(network, target, source_ir, name, input_val, start, shape, stride)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
6+
7+
8+
class TestWhereConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
("2d_condition_xshape_yshape", (2, 2), (2, 2)),
12+
("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)),
13+
("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)),
14+
("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)),
15+
]
16+
)
17+
def test_(self, _, x_size, y_size):
18+
class Where(nn.Module):
19+
def forward(self, condition, x, y):
20+
return torch.where(condition, x, y)
21+
22+
inputX = torch.randn(*x_size)
23+
inputOther = torch.randn(*y_size)
24+
condition = inputX < 0
25+
self.run_test(
26+
Where(),
27+
(condition, inputX, inputOther),
28+
expected_ops={torch.ops.aten.where.self},
29+
)
30+
31+
32+
if __name__ == "__main__":
33+
run_tests()

0 commit comments

Comments
 (0)