Skip to content

Commit 7551eee

Browse files
apbosegs-olive
authored andcommitted
converter reorg and matmul
Matmul issue fixes and lint error check moving matmul to individual file
1 parent 9bbdc9e commit 7551eee

File tree

4 files changed

+182
-1
lines changed

4 files changed

+182
-1
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from torch_tensorrt.fx.converters.impl.squeeze import squeeze
3232
from torch_tensorrt.fx.converters.impl.select import select
3333
from torch_tensorrt.fx.converters.impl.slice import slice_op
34+
from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
3435

3536
_LOGGER: logging.Logger = logging.getLogger(__name__)
3637

@@ -239,7 +240,6 @@ def aten_ops_hardtanh(
239240
kwargs: Dict[str, Argument],
240241
name: str,
241242
) -> Union[TRTTensor, Sequence[TRTTensor]]:
242-
243243
return activation.hardtanh(
244244
network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
245245
)
@@ -262,6 +262,18 @@ def aten_ops_gelu(
262262
)
263263

264264

265+
@tensorrt_converter(torch.ops.aten.matmul)
266+
@tensorrt_converter(torch.ops.aten.mm.default)
267+
def aten_ops_matmul(
268+
network: TRTNetwork,
269+
target: Target,
270+
args: Tuple[Argument, ...],
271+
kwargs: Dict[str, Argument],
272+
name: str,
273+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
274+
return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
275+
276+
265277
@tensorrt_converter(torch.ops.aten.fmod.Tensor)
266278
def aten_ops_fmod(
267279
network: TRTNetwork,

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from torch_tensorrt.fx.converters.converter_utils import (
1515
SourceIR,
1616
get_trt_tensor,
17+
broadcast,
18+
set_layer_name,
1719
)
1820

1921
from torch_tensorrt.fx.converters.impl.elementwise.base import (
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import operator
2+
import warnings
3+
from typing import Optional, cast, Any
4+
5+
import numpy as np
6+
7+
import tensorrt as trt
8+
import torch
9+
from torch.fx.node import Target
10+
11+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
12+
from torch_tensorrt.fx.utils import torch_dtype_from_trt
13+
14+
from torch_tensorrt.fx.converters.converter_utils import (
15+
SourceIR,
16+
get_trt_tensor,
17+
broadcast,
18+
set_layer_name,
19+
)
20+
21+
22+
def matrix_multiply(
23+
network: TRTNetwork,
24+
target: Target,
25+
source_ir: Optional[SourceIR],
26+
name: str,
27+
input: TRTTensor,
28+
other: TRTTensor,
29+
) -> TRTTensor:
30+
if not isinstance(input, trt.tensorrt.ITensor):
31+
input = get_trt_tensor(network, input, f"{name}_input")
32+
if not isinstance(other, trt.tensorrt.ITensor):
33+
other = get_trt_tensor(
34+
network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
35+
)
36+
37+
input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
38+
preset_diff = 0
39+
40+
if len(input.shape) == 1:
41+
preset_diff -= 1
42+
input_matrix_op = trt.MatrixOperation.VECTOR
43+
44+
if len(other.shape) == 1:
45+
preset_diff += 1
46+
other_matrix_op = trt.MatrixOperation.VECTOR
47+
48+
input, other = broadcast(
49+
network, input, other, f"{name}_input", f"{name}_other", preset_diff
50+
)
51+
layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
52+
set_layer_name(layer, target, name)
53+
return layer.get_output(0)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
5+
from parameterized import param, parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
8+
9+
import torch
10+
import torch.nn as nn
11+
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
12+
from parameterized import parameterized
13+
from torch.testing._internal.common_utils import run_tests
14+
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
15+
16+
17+
class TestMatMulConverter(DispatchTestCase):
18+
@parameterized.expand(
19+
[
20+
("2_2", (2, 3), (3, 2)),
21+
("2_2", (2, 3), (3, 1)),
22+
# FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
23+
# (2,3), (3,) torch.ops.aten.mv.default
24+
# Following cases use torch.ops.aten.bmm.defauly
25+
# ("4_3", (3,1,3,2), (2,2,3)),
26+
# ("3_4", (3,1,3,2), (2,2,3)),
27+
# ("3_4", (2, 2, 3), (3, 1, 3, 3)),
28+
# ("4_2", (1, 2, 2, 3), (3, 2)),
29+
]
30+
)
31+
def test_matmul_other_constant(self, _, input_shape, other_shape):
32+
class MatMul(nn.Module):
33+
def __init__(self):
34+
super().__init__()
35+
self.other = nn.Parameter(torch.randn(*other_shape))
36+
37+
def forward(self, input):
38+
return torch.matmul(input, self.other)
39+
40+
inputs = [torch.randn(*input_shape)]
41+
42+
self.run_test(
43+
MatMul(),
44+
inputs,
45+
expected_ops={torch.ops.aten.mm.default},
46+
test_explicit_batch_dim=(len(input_shape) >= 1),
47+
)
48+
49+
@parameterized.expand(
50+
[
51+
("2_2", (2, 3), (3, 2)),
52+
("1_2", (1, 3), (3, 2)),
53+
# FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
54+
# (2,3), (3,) torch.ops.aten.mv.default
55+
# Following cases use torch.ops.aten.bmm.defauly
56+
# ("4_3", (3,1,3,2), (2,2,3)),
57+
# ("3_4", (3,1,3,2), (2,2,3)),
58+
# ("3_4", (2, 2, 3), (3, 1, 3, 3)),
59+
# ("4_2", (1, 2, 2, 3), (3, 2)),
60+
]
61+
)
62+
def test_matmul_input_constant(self, _, input_shape, other_shape):
63+
class MatMul(nn.Module):
64+
def __init__(self):
65+
super().__init__()
66+
self.input = nn.Parameter(torch.randn(*input_shape))
67+
68+
def forward(self, other):
69+
return torch.matmul(self.input, other)
70+
71+
inputs = [torch.randn(*other_shape)]
72+
73+
self.run_test(
74+
MatMul(),
75+
inputs,
76+
expected_ops={torch.ops.aten.mm.default},
77+
test_explicit_batch_dim=True
78+
# test_explicit_batch_dim=(len(other_shape) <= 2),
79+
)
80+
81+
@parameterized.expand(
82+
[
83+
("2_2", (2, 3), (3, 2)),
84+
# ("2_3", (2, 3), (2, 3, 4)),
85+
# ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)),
86+
# ("4_2", (2, 1, 2, 3), (3, 2)),
87+
# ("2_1", (2, 3), (3,)),
88+
# ("1_2", (3,), (3, 2)),
89+
# ("1_1", (3,), (3,)),
90+
]
91+
)
92+
def test_matmul(self, _, input_shape, other_shape):
93+
class MatMul(nn.Module):
94+
def forward(self, input, other):
95+
return torch.matmul(input, other)
96+
97+
inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
98+
test_explicit_batch_dim = not (
99+
input_shape[0] == other_shape[0]
100+
and len(input_shape) > 2
101+
and len(other_shape) > 2
102+
)
103+
self.run_test(
104+
MatMul(),
105+
inputs,
106+
expected_ops={torch.ops.aten.mm.default},
107+
test_explicit_batch_dim=test_explicit_batch_dim,
108+
)
109+
110+
# FIXME: dynamic shape is giving bmm
111+
112+
113+
if __name__ == "__main__":
114+
run_tests()

0 commit comments

Comments
 (0)