Skip to content

Commit 4c19de9

Browse files
authored
[mlir][vector] Disable vector.matrix_multiply for scalable vectors (#102573)
Disables `vector.matrix_multiply` for scalable vectors. As per the docs: > This is the counterpart of llvm.matrix.multiply in MLIR I'm not aware of any use of matrix-multiply intrinsics in the context of scalable vectors, hence disabling.
1 parent 574e958 commit 4c19de9

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2688,13 +2688,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
26882688
TCresVTEtIsSameAsOpBase<0, 1>>]>,
26892689
Arguments<(
26902690
// TODO: tighten vector element types that make sense.
2691-
ins VectorOfRankAndType<[1],
2691+
ins FixedVectorOfRankAndType<[1],
26922692
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
2693-
VectorOfRankAndType<[1],
2693+
FixedVectorOfRankAndType<[1],
26942694
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
26952695
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
26962696
Results<(
2697-
outs VectorOfRankAndType<[1],
2697+
outs FixedVectorOfRankAndType<[1],
26982698
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
26992699
{
27002700
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
@@ -2712,7 +2712,9 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
27122712
<rhs_columns> and multiplies them. The result matrix is returned embedded in
27132713
the result vector.
27142714

2715-
Also see:
2715+
Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
2716+
support scalable vectors. Hence, this Op is only available for fixed-width
2717+
vectors. Also see:
27162718

27172719
http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
27182720

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,14 @@ class VectorOfRankAndType<list<int> allowedRanks,
494494
VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
495495
"::mlir::VectorType">;
496496

497+
// Fixed-width vector where the rank is from the given `allowedRanks` list and
498+
// the type is from the given `allowedTypes` list
499+
class FixedVectorOfRankAndType<list<int> allowedRanks,
500+
list<Type> allowedTypes> : AllOfType<
501+
[FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
502+
FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
503+
"::mlir::VectorType">;
504+
497505
// Whether the number of elements of a vector is from the given
498506
// `allowedLengths` list
499507
class IsVectorOfLengthPred<list<int> allowedLengths> :
@@ -592,7 +600,7 @@ class VectorOfLengthAndType<list<int> allowedLengths,
592600
// Any fixed-length vector where the number of elements is from the given
593601
// `allowedLengths` list and the type is from the given `allowedTypes` list
594602
class FixedVectorOfLengthAndType<list<int> allowedLengths,
595-
list<Type> allowedTypes> : AllOfType<
603+
list<Type> allowedTypes> : AllOfType<
596604
[FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
597605
FixedVectorOf<allowedTypes>.summary #
598606
FixedVectorOfLength<allowedLengths>.summary,

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,3 +1862,16 @@ func.func @invalid_step_2d() {
18621862
vector.step : vector<2x4xf32>
18631863
return
18641864
}
1865+
1866+
// -----
1867+
1868+
func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) {
1869+
// expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}}
1870+
%c = vector.matrix_multiply %a, %b {
1871+
lhs_rows = 2: i32,
1872+
lhs_columns = 2: i32 ,
1873+
rhs_columns = 2: i32 }
1874+
: (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64>
1875+
1876+
return
1877+
}

0 commit comments

Comments
 (0)