Skip to content

[nlir][vector] Constrain ContractionOpToMatmulOpLowering #102225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 7, 2024

Conversation

banach-space
Copy link
Contributor

Disables ContractionOpToMatmulOpLowering for scalable vectors. This
pattern is meant to enable lowering to llvm.matrix.multiply - I'm not
aware of any use of that in the context of scalable vectors.

@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Disables ContractionOpToMatmulOpLowering for scalable vectors. This
pattern is meant to enable lowering to llvm.matrix.multiply - I'm not
aware of any use of that in the context of scalable vectors.


Full diff: https://github.com/llvm/llvm-project/pull/102225.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+10-3)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+12-2)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 3a799ce8e0bce3..97c3af781a92a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1283,6 +1283,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
 /// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
     vector::ContractionOp op, MaskingOpInterface maskOp,
     PatternRewriter &rew) const {
@@ -1301,14 +1303,19 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
       !isParallelIterator(iteratorTypes[1]) ||
       !isReductionIterator(iteratorTypes[2]))
     return failure();
+  
+  Type opResType = op.getType();
+  VectorType vecType = dyn_cast<VectorType>(opResType);
+  if (vecType && vecType.isScalable()) {
+    // This should be sufficient to reject all cases with scalable vectors.
+    return failure();
+  }
 
   Type elementType = op.getLhsType().getElementType();
   if (!elementType.isIntOrFloat())
     return failure();
 
-  Type dstElementType = op.getType();
-  if (auto vecType = dyn_cast<VectorType>(dstElementType))
-    dstElementType = vecType.getElementType();
+  Type dstElementType = vecType ? vecType.getElementType() : opResType;
   if (elementType != dstElementType)
     return failure();
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 78cf82e1ab6c1a..4867a416e5d144 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -36,13 +36,23 @@
 //      CHECK:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
 //      CHECK:  %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
 func.func @matmul(%arg0: vector<2x4xf32>,
-                          %arg1: vector<4x3xf32>,
-                          %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+                  %arg1: vector<4x3xf32>,
+                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
     : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-NOT: vector.matrix_multiply
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+                           %arg1: vector<4x[3]xf32>,
+                           %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

Copy link

github-actions bot commented Aug 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Disables `ContractionOpToMatmulOpLowering` for scalable vectors. This
pattern is meant to enable lowering to `llvm.matrix.multiply` - I'm not
aware of any use of that in the context of scalable vectors.
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left one minor comment but otherwise LGTM cheers. I do have a wider question tho, should vector.matrix_multiply be disabled for scalable vectors at the op-level?

@banach-space
Copy link
Contributor Author

I do have a wider question tho, should vector.matrix_multiply be disabled for scalable vectors at the op-level?

+1 I'll do that in a follow-up.

@banach-space banach-space merged commit cb89457 into llvm:main Aug 7, 2024
5 of 7 checks passed
@nujaa
Copy link
Contributor

nujaa commented Aug 7, 2024

[nlir] 🥲

@banach-space
Copy link
Contributor Author

[nlir] 🥲

Muscle memory failing me :(

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants