diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 6cd44f4855..99007e2a4d 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -171,6 +171,7 @@ def aten_ops_gelu(
 
 @dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mv.default)  # type: ignore[misc]
 def aten_ops_matmul(
     network: TRTNetwork,
     target: Target,
@@ -179,7 +180,12 @@ def aten_ops_matmul(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.matmul.matrix_multiply(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
     )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
index 3e1bef66ef..4b69b09d2a 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.fx.converters.converter_utils import (
@@ -10,8 +11,6 @@
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-import tensorrt as trt
-
 
 def matrix_multiply(
     network: TRTNetwork,
@@ -20,6 +19,8 @@
     name: str,
     input: TRTTensor,
     other: TRTTensor,
+    input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
+    other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
 ) -> TRTTensor:
     if not isinstance(input, trt.tensorrt.ITensor):
         input = get_trt_tensor(network, input, f"{name}_input")
@@ -31,7 +32,6 @@
             dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
         )
 
-    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
     preset_diff = 0
 
     if len(input.shape) == 1:
@@ -46,5 +46,5 @@
         network, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
     layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_matmul_aten.py b/tests/py/dynamo/conversion/test_matmul_aten.py
index c0220d1808..816686c4ec 100644
--- a/tests/py/dynamo/conversion/test_matmul_aten.py
+++ b/tests/py/dynamo/conversion/test_matmul_aten.py
@@ -9,18 +9,44 @@
 class TestMatMulConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2_2", (2, 3), (3, 2)),
-            ("2_2", (2, 3), (3, 1)),
-            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
-            # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
+            (
+                "2_2",
+                (2, 3),
+                (3, 2),
+            ),
+            (
+                "4_6",
+                (4, 5),
+                (5, 6),
+            ),
+            (
+                "2_1",
+                (2, 3),
+                (3, 1),
+            ),
+            (
+                "4_1",
+                (4, 1),
+                (1, 1),
+            ),
+            (
+                "1_2",
+                (1, 3),
+                (3, 2),
+            ),
+            (
+                "1_3",
+                (1, 2),
+                (2, 3),
+            ),
+            # Following cases use torch.ops.aten.bmm.default
             # ("4_3", (3,1,3,2), (2,2,3)),
             # ("3_4", (3,1,3,2), (2,2,3)),
             # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
             # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
-    def test_matmul_other_constant(self, _, input_shape, other_shape):
+    def test_matmul_mm(self, _, input_shape, other_shape):
         class MatMul(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -39,32 +65,43 @@
     @parameterized.expand(
         [
-            ("2_2", (2, 3), (3, 2)),
-            ("1_2", (1, 3), (3, 2)),
-            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
-            # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
-            # ("4_3", (3,1,3,2), (2,2,3)),
-            # ("3_4", (3,1,3,2), (2,2,3)),
-            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            # ("4_2", (1, 2, 2, 3), (3, 2)),
+            (
+                "1_1",
+                (1, 1),
+                (1,),
+            ),
+            (
+                "1_1",
+                (1, 2),
+                (2,),
+            ),
+            (
+                "2_1",
+                (2, 1),
+                (1,),
+            ),
+            (
+                "3_1",
+                (3, 4),
+                (4,),
+            ),
         ]
     )
-    def test_matmul_input_constant(self, _, input_shape, other_shape):
+    def test_matmul_mv(self, _, input_shape, other_shape):
         class MatMul(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.input = nn.Parameter(torch.randn(*input_shape))
+                self.other = nn.Parameter(torch.randn(*other_shape))
 
-            def forward(self, other):
-                return torch.matmul(self.input, other)
+            def forward(self, input):
+                return torch.matmul(input, self.other)
 
-        inputs = [torch.randn(*other_shape)]
+        inputs = [torch.randn(*input_shape)]
 
         self.run_test(
             MatMul(),
             inputs,
-            expected_ops={torch.ops.aten.mm.default},
+            expected_ops={torch.ops.aten.mv.default},
         )
 
 
     @parameterized.expand(