Skip to content

Commit 7d9705d

Browse files
committed
feat: Add permute operation implementation
1 parent e6a503a commit 7d9705d

File tree

4 files changed

+126
-0
lines changed

4 files changed

+126
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,21 @@ def aten_ops_slice(
358358
args[3],
359359
args_bounds_check(args, 4, replacement=1),
360360
)
361+
362+
363+
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
364+
def aten_ops_slice(
365+
network: TRTNetwork,
366+
target: Target,
367+
args: Tuple[Argument, ...],
368+
kwargs: Dict[str, Argument],
369+
name: str,
370+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
371+
return impl.permutation.permute(
372+
network,
373+
target,
374+
SourceIR.ATEN,
375+
name,
376+
args[0],
377+
args[1],
378+
)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
from . import shape
1212
from . import squeeze
1313
from . import unsqueeze
14+
from . import permutation
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Optional, Sequence, cast
2+
3+
4+
from torch.fx.node import Target
5+
6+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
7+
from torch_tensorrt.dynamo.conversion import SourceIR
8+
from torch_tensorrt.fx.converters.converter_utils import (
9+
set_layer_name,
10+
get_positive_dim,
11+
)
12+
13+
14+
def permute(
15+
network: TRTNetwork,
16+
target: Target,
17+
source_ir: Optional[SourceIR],
18+
name: str,
19+
input: TRTTensor,
20+
permutation: Sequence[int],
21+
) -> TRTTensor:
22+
if not isinstance(input, TRTTensor):
23+
raise RuntimeError(
24+
f"permute received input {input} that is not a TensorRT ITensor"
25+
)
26+
27+
permutation = [
28+
get_positive_dim(i, len(input.shape)) for i in cast(Sequence[int], permutation)
29+
]
30+
31+
layer = network.add_shuffle(input)
32+
layer.second_transpose = tuple(permutation)
33+
set_layer_name(layer, target, name, source_ir)
34+
return layer.get_output(0)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
6+
from torch_tensorrt import Input
7+
8+
9+
class TestPermuteConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("positive", [0, 2, 1]),
13+
("negative", [0, -1, -2]),
14+
]
15+
)
16+
def test_permute_list(self, _, permutation):
17+
class Permute(nn.Module):
18+
def forward(self, x):
19+
return x.permute(permutation)
20+
21+
inputs = [torch.randn(1, 3, 2)]
22+
self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default})
23+
24+
@parameterized.expand(
25+
[
26+
("positive", [0, 2, 1]),
27+
("negative", [0, -1, -2]),
28+
]
29+
)
30+
def test_permute(self, _, permutation):
31+
class Permute(nn.Module):
32+
def forward(self, x):
33+
return x.permute(*permutation)
34+
35+
inputs = [torch.randn(1, 3, 2)]
36+
self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default})
37+
38+
def test_permute_with_dynamic_shape(self):
39+
class Permute(nn.Module):
40+
def forward(self, x):
41+
return x.permute(1, 2, 0)
42+
43+
input_specs = [
44+
Input(
45+
shape=(-1, -1, -1),
46+
dtype=torch.float32,
47+
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
48+
),
49+
]
50+
self.run_test_with_dynamic_shape(
51+
Permute(), input_specs, expected_ops={torch.ops.aten.permute.default}
52+
)
53+
54+
def test_permute_with_dynamic_shape_four_dimensions(self):
55+
class Permute(nn.Module):
56+
def forward(self, x):
57+
return x.permute(1, 2, 3, 0)
58+
59+
input_specs = [
60+
Input(
61+
shape=(-1, -1, -1, -1),
62+
dtype=torch.float32,
63+
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
64+
),
65+
]
66+
67+
self.run_test_with_dynamic_shape(
68+
Permute(), input_specs, expected_ops={torch.ops.aten.permute.default}
69+
)
70+
71+
72+
if __name__ == "__main__":
73+
run_tests()

0 commit comments

Comments
 (0)