Skip to content

Conversation

amd-eochoalo
Copy link
Contributor

@amd-eochoalo amd-eochoalo commented Aug 13, 2025

Allows alignment to be propagated correctly from vector to LLVM dialect operations.

@amd-eochoalo amd-eochoalo force-pushed the eochoa/2025-08-13/alignment-flow branch from 88979e9 to 514cde7 Compare August 13, 2025 20:34
Copy link

github-actions bot commented Aug 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@amd-eochoalo amd-eochoalo force-pushed the eochoa/2025-08-13/alignment-flow branch from 514cde7 to 3b48a3a Compare August 13, 2025 20:37
@amd-eochoalo amd-eochoalo force-pushed the eochoa/2025-08-13/alignment-flow branch from 3b48a3a to 69c6fdf Compare August 25, 2025 13:46
@amd-eochoalo amd-eochoalo marked this pull request as ready for review August 25, 2025 13:56
@llvmbot llvmbot added the mlir label Aug 25, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 25, 2025

@llvm/pr-subscribers-mlir

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

Allows alignment to be propagated correctly from vector to LLVM dialect operations.


Full diff: https://github.com/llvm/llvm-project/pull/153482.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+19-6)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+39)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index afc3d1b12ac0d..7d29750ddcf39 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -299,8 +299,9 @@ class VectorGatherOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    unsigned align = gather.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
@@ -354,8 +355,9 @@ class VectorScatterOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    unsigned align = scatter.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(scatter,
                                          "could not resolve alignment");
@@ -399,8 +401,14 @@ class VectorExpandLoadOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+    //   The pointer alignment defaults to 1.
+    uint64_t alignment = expand.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
+        alignment);
     return success();
   }
 };
@@ -421,8 +429,13 @@ class VectorCompressStoreOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+    //   The pointer alignment defaults to 1.
+    uint64_t alignment = compress.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 9b57b1b6fb4c7..5973c2ba2cbd0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
 
 // -----
 
+func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_with_alignment
+// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
 
+// -----
+
+func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_alignment
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a
 
 // -----
 
+func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
+  %0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
+  return %0 : vector<11xindex>
+}
+// CHECK-LABEL: func @expand_load_op_with_alignment
+// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.compressstore
 //===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
 
 // -----
 
+func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
+  vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
+  return
+}
+// CHECK-LABEL: func @compress_store_op_with_alignment
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.splat
 //===----------------------------------------------------------------------===//

@kuhar kuhar requested a review from krzysz00 August 25, 2025 19:47
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +302 to +304
unsigned align = gather.getAlignment().value_or(0);
if (!align &&
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this means that explicit alignment takes priority over e.g. --convert-vector-to-llvm='use-vector-alignment=1. This feels like the right design decision, but we should make sure that it is documented (perhaps just add a comment here?) and tested (e.g. in "use-vector-alignment.mlir").

Similar comment for scatter. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review @banach-space! See here for the changes:

b6e5aff
746412a
fdfc873

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Sorry I didn't replay earlier, I was on a sick leave.

Alignment can be specified either explicitly via
attributes in vector operations or via the option
to use-vector-alignment=<N> in the --convert-vector-to-llvm
pass. The explicit attribute takes precedent over the option
used as input to the pass.
@amd-eochoalo amd-eochoalo self-assigned this Aug 27, 2025
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@amd-eochoalo amd-eochoalo merged commit b5f6ce6 into llvm:main Sep 3, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants