Skip to content

Commit 2cc80d8

Browse files
authored
feat: support aten.diagonal converter (#2856)
1 parent 1e272b4 commit 2cc80d8

File tree

4 files changed

+276
-0
lines changed

4 files changed

+276
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3036,6 +3036,64 @@ def aten_ops_flip(
30363036
)
30373037

30383038

3039+
def zero_diag_size_validator(node: Node) -> bool:
    """Reject ``aten.diagonal`` nodes whose diagonal would be empty.

    TensorRT cannot represent the zero-element tensor that
    ``aten.diagonal`` returns when the requested diagonal lies entirely
    outside the matrix spanned by ``dim1``/``dim2``.  Such nodes — and
    nodes lacking shape metadata — are sent back to PyTorch by returning
    ``False``.

    Args:
        node: The FX node for ``torch.ops.aten.diagonal.default``.

    Returns:
        ``True`` if the diagonal is non-empty and the node can be
        converted, ``False`` to fall back to PyTorch.
    """
    meta = node.args[0].meta.get("tensor_meta")
    if meta:
        input_shape = meta.shape
    else:
        _LOGGER.warning(
            "Meta information of input is missing. Unable to validate diagonal size, falling back to PyTorch operation."
        )
        return False

    # aten.diagonal.default declares defaults (offset=0, dim1=0, dim2=1);
    # traced graphs may omit trailing defaulted arguments, so index safely
    # (mirrors the args_bounds_check usage in the converter itself).
    offset = node.args[1] if len(node.args) > 1 else 0
    dim1 = node.args[2] if len(node.args) > 2 else 0
    dim2 = node.args[3] if len(node.args) > 3 else 1

    num_dims = len(input_shape)

    # Adjust dimensions to be positive and canonicalize
    dim1 = get_positive_dim(dim1, num_dims)
    dim2 = get_positive_dim(dim2, num_dims)

    # Diagonal length, clamped at zero — same formula as torch._refs.diagonal.
    if offset >= 0:
        diag_size = max(min(input_shape[dim1], input_shape[dim2] - offset), 0)
    else:
        diag_size = max(min(input_shape[dim1] + offset, input_shape[dim2]), 0)

    if diag_size == 0:
        _LOGGER.debug(
            "Diagonal size is zero, resulting in an empty tensor which is not supported for this operation."
        )
        return False
    else:
        return True
3073+
3074+
3075+
@dynamo_tensorrt_converter(
    torch.ops.aten.diagonal.default, capability_validator=zero_diag_size_validator
)
def aten_ops_diagonal(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Dispatch ``aten.diagonal`` to the TensorRT slice implementation."""
    # Trailing arguments are optional in the schema and fall back to the
    # declared defaults: offset=0, dim1=0, dim2=1.
    offset = args_bounds_check(args, 1, replacement=0)
    dim1 = args_bounds_check(args, 2, replacement=0)
    dim2 = args_bounds_check(args, 3, replacement=1)
    return impl.slice.diagonal(
        ctx, target, SourceIR.ATEN, name, args[0], offset, dim1, dim2
    )
3095+
3096+
30393097
@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
30403098
def aten_ops_scalar_tensor(
30413099
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,34 @@ def set_item(
653653
0,
654654
)
655655
return ans
656+
657+
658+
def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
    """Compute contiguous (row-major) strides for ``shape``.

    The stride of a dimension is the number of elements to skip in flat
    memory to advance one step along that dimension; the last dimension
    always has stride 1 because its elements are stored contiguously.

    Example:
        For ``shape = [2, 3, 4]`` the result is ``[12, 4, 1]``:
        stride[2] = 1, stride[1] = 1 * 4 = 4, stride[0] = 4 * 3 = 12.
        That is, moving along the first dimension skips 12 elements,
        along the second 4 elements, and along the third 1 element.
    """
    # Accumulate a running product right-to-left, then flip into place.
    reversed_strides = []
    running_size = 1
    for dim_size in reversed(shape):
        reversed_strides.append(running_size)
        running_size *= dim_size
    reversed_strides.reverse()
    return reversed_strides

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch_tensorrt.dynamo.conversion import impl
99
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1010
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
calculate_strides,
1112
flatten_dims,
1213
get_positive_dim,
1314
get_trt_tensor,
@@ -262,6 +263,68 @@ def flip(
262263
return layer.get_output(0)
263264

264265

266+
def diagonal(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    offset: int,
    dim1: int,
    dim2: int,
) -> TRTTensor:
    """Extract the diagonal of ``input`` over dimensions ``dim1``/``dim2``.

    This implementation is inspired by the reference implementation in PyTorch:
    https://github.com/pytorch/pytorch/blob/082251e76b93b277ff2791d0e2b64934add34644/torch/_refs/__init__.py#L4255

    Raises:
        ValueError: If the requested diagonal is empty (the capability
            validator should have filtered such nodes out already).
    """
    shape = input.shape
    rank = len(shape)

    # Canonicalize possibly-negative dimension indices.
    dim1 = get_positive_dim(dim1, rank)
    dim2 = get_positive_dim(dim2, rank)

    # Length of the requested diagonal, clamped at zero when it falls
    # entirely outside the matrix spanned by dim1/dim2.
    if offset >= 0:
        diag_size = max(min(shape[dim1], shape[dim2] - offset), 0)
    else:
        diag_size = max(min(shape[dim1] + offset, shape[dim2]), 0)

    if diag_size == 0:
        raise ValueError("The size of the diagonal is non-positive.")

    strides = calculate_strides(shape)

    # A positive offset shifts the starting element along dim2,
    # a negative one along dim1.
    storage_offset = offset * strides[dim2] if offset >= 0 else -offset * strides[dim1]

    # Drop dim1/dim2 from the output shape and append the diagonal as a
    # trailing dimension.
    sizes = [shape[i] for i in range(rank) if i not in (dim1, dim2)]
    sizes.append(diag_size)

    # Stepping along the diagonal advances one element in each of dim1
    # and dim2 simultaneously, hence the summed stride.
    output_strides = [strides[i] for i in range(rank) if i not in (dim1, dim2)]
    output_strides.append(strides[dim1] + strides[dim2])

    # Materialize the strided view of the diagonal elements.
    return as_strided(
        ctx,
        target,
        source_ir,
        f"{name}_as_strided",
        input,
        sizes,
        output_strides,
        storage_offset,
    )
326+
327+
265328
def as_strided(
266329
ctx: ConversionContext,
267330
target: Target,
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestDiagonalConverter(DispatchTestCase):
    """Verify ``torch.ops.aten.diagonal`` lowering against eager PyTorch.

    Renamed from the misleading ``TestAsStridedConverter`` — every case
    here exercises the diagonal converter, not ``as_strided`` directly.
    """

    @parameterized.expand(
        [
            # (input_shape, offset, dim1, dim2)
            ((3, 3), 1, 0, 1),
            ((3, 3), 1, 0, -1),
            ((3, 4), 1, 0, 1),
            ((5, 4, 2), -1, 1, 2),
            ((5, 4, 2), 1, 2, 0),
            ((6, 5, 4), 1, 0, 1),
            ((2, 5, 4, 2), 0, 0, 1),
            ((2, 5, 4, 2), 1, 1, 2),
            ((2, 5, 4, 2), 1, -1, 2),
            ((2, 5, 4, 2), 1, 1, -2),
            ((2, 5, 4, 2), 1, -1, -2),
            ((2, 5, 4, 2), 0, 0, 2),
            ((2, 5, 4, 2), -1, 1, 2),
            ((2, 5, 4, 2, 6), 1, 1, 2),
            ((2, 5, 4, 2, 5, 6), 1, 1, 2),
        ]
    )
    def test_diagonal(
        self,
        input_shape,
        offset,
        dim1,
        dim2,
    ):
        # Wrap the aten op in a module so the harness can trace/compile it.
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.diagonal.default(x, offset, dim1, dim2)

        inputs = [torch.randn(input_shape)]
        self.run_test(
            TestModule(),
            inputs,
            enable_passes=True,
        )
121+
122+
123+
# Entry point: run this converter test file directly via PyTorch's test runner.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)