Skip to content

Commit cb89457

Browse files
authored
[nlir][vector] Constrain ContractionOpToMatmulOpLowering (#102225)
Disables `ContractionOpToMatmulOpLowering` for scalable vectors. This pattern is meant to enable lowering to `llvm.matrix.multiply` - I'm not aware of any use of that in the context of scalable vectors.
1 parent 9dae7fc commit cb89457

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
12831283
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
12841284
/// vector.transpose operations are inserted if the vector.contract op is not a
12851285
/// row-major matrix multiply.
1286+
///
1287+
/// Scalable vectors are not supported.
12861288
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
12871289
vector::ContractionOp op, MaskingOpInterface maskOp,
12881290
PatternRewriter &rew) const {
@@ -1302,13 +1304,18 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13021304
!isReductionIterator(iteratorTypes[2]))
13031305
return failure();
13041306

1307+
Type opResType = op.getType();
1308+
VectorType vecType = dyn_cast<VectorType>(opResType);
1309+
if (vecType && vecType.isScalable()) {
1310+
// Note - this is sufficient to reject all cases with scalable vectors.
1311+
return failure();
1312+
}
1313+
13051314
Type elementType = op.getLhsType().getElementType();
13061315
if (!elementType.isIntOrFloat())
13071316
return failure();
13081317

1309-
Type dstElementType = op.getType();
1310-
if (auto vecType = dyn_cast<VectorType>(dstElementType))
1311-
dstElementType = vecType.getElementType();
1318+
Type dstElementType = vecType ? vecType.getElementType() : opResType;
13121319
if (elementType != dstElementType)
13131320
return failure();
13141321

mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,23 @@
3636
// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
3737
// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
3838
func.func @matmul(%arg0: vector<2x4xf32>,
39-
%arg1: vector<4x3xf32>,
40-
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
39+
%arg1: vector<4x3xf32>,
40+
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
4141
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
4242
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
4343
return %0 : vector<2x3xf32>
4444
}
4545

46+
// CHECK-LABEL: func @matmul_scalable
47+
// CHECK-NOT: vector.matrix_multiply
48+
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
49+
%arg1: vector<4x[3]xf32>,
50+
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
51+
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
52+
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
53+
return %0 : vector<2x[3]xf32>
54+
}
55+
4656
module attributes {transform.with_named_sequence} {
4757
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
4858
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)