-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Propagate alignment from vector to llvm dialects. #153482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Propagate alignment from vector to llvm dialects. #153482
Conversation
88979e9
to
514cde7
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
514cde7
to
3b48a3a
Compare
3b48a3a
to
69c6fdf
Compare
@llvm/pr-subscribers-mlir Author: Erick Ochoa Lopez (amd-eochoalo) ChangesAllows alignment to be propagated correctly from vector to LLVM dialect operations. Full diff: https://github.com/llvm/llvm-project/pull/153482.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index afc3d1b12ac0d..7d29750ddcf39 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -299,8 +299,9 @@ class VectorGatherOpConversion
}
// Resolve alignment.
- unsigned align;
- if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ unsigned align = gather.getAlignment().value_or(0);
+ if (!align &&
+ failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
@@ -354,8 +355,9 @@ class VectorScatterOpConversion
}
// Resolve alignment.
- unsigned align;
- if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ unsigned align = scatter.getAlignment().value_or(0);
+ if (!align &&
+ failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(scatter,
"could not resolve alignment");
@@ -399,8 +401,14 @@ class VectorExpandLoadOpConversion
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
+ // From:
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // The pointer alignment defaults to 1.
+ uint64_t alignment = expand.getAlignment().value_or(1);
+
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+ expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
+ alignment);
return success();
}
};
@@ -421,8 +429,13 @@ class VectorCompressStoreOpConversion
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
+ // From:
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ // The pointer alignment defaults to 1.
+ uint64_t alignment = compress.getAlignment().value_or(1);
+
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+ compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 9b57b1b6fb4c7..5973c2ba2cbd0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
// -----
+func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_with_alignment
+// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
+// -----
+
+func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_alignment
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+
// -----
//===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a
// -----
+func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
+ %0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
+ return %0 : vector<11xindex>
+}
+// CHECK-LABEL: func @expand_load_op_with_alignment
+// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.compressstore
//===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
// -----
+func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
+ vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
+ return
+}
+// CHECK-LABEL: func @compress_store_op_with_alignment
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.splat
//===----------------------------------------------------------------------===//
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
unsigned align = gather.getAlignment().value_or(0); | ||
if (!align && | ||
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this means that explicit alignment takes priority over e.g. --convert-vector-to-llvm='use-vector-alignment=1
. This feels like the right design decision, but we should make sure that it is documented (perhaps just add a comment here?) and tested (e.g. in "use-vector-alignment.mlir").
Similar comment for scatter. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review @banach-space! See here for the changes:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Sorry I didn't replay earlier, I was on a sick leave.
Alignment can be specified either explicitly via attributes in vector operations or via the option to use-vector-alignment=<N> in the --convert-vector-to-llvm pass. The explicit attribute takes precedent over the option used as input to the pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Allows alignment to be propagated correctly from vector to LLVM dialect operations.