-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[nlir][vector] Constrain ContractionOpToMatmulOpLowering
#102225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nlir][vector] Constrain ContractionOpToMatmulOpLowering
#102225
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesDisables Full diff: https://github.com/llvm/llvm-project/pull/102225.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 3a799ce8e0bce3..97c3af781a92a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1283,6 +1283,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rew) const {
@@ -1301,14 +1303,19 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
!isParallelIterator(iteratorTypes[1]) ||
!isReductionIterator(iteratorTypes[2]))
return failure();
+
+ Type opResType = op.getType();
+ VectorType vecType = dyn_cast<VectorType>(opResType);
+ if (vecType && vecType.isScalable()) {
+ // This should be sufficient to reject all cases with scalable vectors.
+ return failure();
+ }
Type elementType = op.getLhsType().getElementType();
if (!elementType.isIntOrFloat())
return failure();
- Type dstElementType = op.getType();
- if (auto vecType = dyn_cast<VectorType>(dstElementType))
- dstElementType = vecType.getElementType();
+ Type dstElementType = vecType ? vecType.getElementType() : opResType;
if (elementType != dstElementType)
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 78cf82e1ab6c1a..4867a416e5d144 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -36,13 +36,23 @@
// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
func.func @matmul(%arg0: vector<2x4xf32>,
- %arg1: vector<4x3xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %arg1: vector<4x3xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-NOT: vector.matrix_multiply
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+ %arg1: vector<4x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Disables `ContractionOpToMatmulOpLowering` for scalable vectors. This pattern is meant to enable lowering to `llvm.matrix.multiply` - I'm not aware of any use of that in the context of scalable vectors.
29e124e
to
31b1868
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left one minor comment but otherwise LGTM cheers. I do have a wider question tho, should vector.matrix_multiply
be disabled for scalable vectors at the op-level?
+1 I'll do that in a follow-up. |
Remove ambiguity in comment
[nlir] 🥲 |
Muscle memory failing me :( |
Disables
ContractionOpToMatmulOpLowering
for scalable vectors. Thispattern is meant to enable lowering to
llvm.matrix.multiply
- I'm notaware of any use of that in the context of scalable vectors.