[mlir][vector][memref] Add alignment attribute to memory access ops #144344


Open: tyb0807 wants to merge 1 commit into main

tyb0807
Contributor

@tyb0807 tyb0807 commented Jun 16, 2025

Alignment information is important to allow LLVM backends such as AMDGPU to select wide memory accesses (e.g., dwordx4 or b128). Since this info is not always inferable, it's better to inform LLVM backends explicitly about it.
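
To make the motivation concrete, here is a toy C++ model of how the alignment a backend may assume bounds the widest single access it can pick. This is not the AMDGPU backend's actual selection logic. The function name `widestAccessBytes` and the 16-byte cap (standing in for a 128-bit access such as dwordx4 or b128) are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Toy model (hypothetical helper, not an LLVM API): return the widest
// power-of-two access width, in bytes, that is no larger than the alignment
// the backend is allowed to assume and no larger than maxWidthBytes.
uint32_t widestAccessBytes(uint32_t knownAlignBytes,
                           uint32_t maxWidthBytes = 16) {
  // Precondition: the known alignment is a nonzero power of two.
  assert(knownAlignBytes != 0 &&
         (knownAlignBytes & (knownAlignBytes - 1)) == 0);
  uint32_t width = 1;
  // Keep doubling while the doubled width still fits both bounds.
  while (width * 2 <= maxWidthBytes && width * 2 <= knownAlignBytes)
    width *= 2;
  return width;
}
```

Under this model, an access known only to be 4-byte aligned is limited to 4-byte pieces, while one known to be 16-byte aligned can be served by a single 16-byte access. Without explicit information, the known alignment may be far smaller than the real one.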

This patch introduces an `alignment` attribute on the MemRef and Vector memory access ops. Propagation of this attribute to LLVM/SPIR-V will be implemented in a separate follow-up PR.

@llvmbot
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: None (tyb0807)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144344.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+58-3)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+48-2)
  • (added) mlir/test/Dialect/MemRef/load-store-alignment.mlir (+27)
  • (added) mlir/test/Dialect/Vector/load-store-alignment.mlir (+27)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (modified) mlir/unittests/Dialect/MemRef/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp (+88)
  • (added) mlir/unittests/Dialect/Vector/CMakeLists.txt (+7)
  • (added) mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp (+95)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 77e3074661abf..160b04e452c5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1227,7 +1227,45 @@ def LoadOp : MemRef_Op<"load",
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, memref, indices, false, alignment);
+    }]>,
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, memref, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1924,13 +1962,30 @@ def MemRef_StoreOp : MemRef_Op<"store",
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
 
   let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
       $_state.addOperands(valueToStore);
       $_state.addOperands(memref);
-    }]>];
+    }]>
+  ];
 
   let extraClassDeclaration = [{
       Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..3cd71491bcc04 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1739,7 +1739,34 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1825,9 +1852,28 @@ def Vector_StoreOp : Vector_Op<"store"> {
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment
   );
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
diff --git a/mlir/test/Dialect/MemRef/load-store-alignment.mlir b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f5a5461e0ac0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: memref.load {{.*}} {alignment = 16 : i32}
+// CHECK: memref.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = memref.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  memref.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: i32) {
+  // expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  memref.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f54d989dd190
--- /dev/null
+++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: vector.load {{.*}} {alignment = 16 : i32}
+// CHECK: vector.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = vector.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  vector.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index aea247547473d..34c9fb7317443 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -18,3 +18,4 @@ add_subdirectory(SPIRV)
 add_subdirectory(SMT)
 add_subdirectory(Transform)
 add_subdirectory(Utils)
+add_subdirectory(Vector)
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..87d33854fadcd 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  LoadStoreAlignment.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..f0b8e93c2d0e1
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
@@ -0,0 +1,88 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}
diff --git a/mlir/unittests/Dialect/Vector/CMakeLists.txt b/mlir/unittests/Dialect/Vector/CMakeLists.txt
new file mode 100644
index 0000000000000..b23d9c2df3870
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRVectorTests
+  LoadStoreAlignment.cpp
+)
+mlir_target_link_libraries(MLIRVectorTests
+  PRIVATE
+  MLIRVectorDialect
+  )
diff --git a/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..745dd8632fe4d
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
@@ -0,0 +1,95 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}

@tyb0807 tyb0807 requested review from kuhar and ftynse June 16, 2025 12:57
@krzysz00 krzysz00 changed the title Add attribute to MemRef/Vector memory access ops Add alignment attribute to MemRef/Vector memory access ops Jun 16, 2025
Contributor

@krzysz00 krzysz00 left a comment


This needs more documentation and lit tests; it doesn't need the unit tests.

I'd be OK with updating the LLVM (and SPIR-V, if applicable) lowerings in this patch or in a stacked-on follow-up.

@kuhar kuhar changed the title Add alignment attribute to MemRef/Vector memory access ops [mlir][vector][memref] Add alignment attribute to MemRef/Vector memory access ops Jun 16, 2025
@kuhar kuhar changed the title [mlir][vector][memref] Add alignment attribute to MemRef/Vector memory access ops [mlir][vector][memref] Add alignment attribute to memory access ops Jun 16, 2025
Member

@kuhar kuhar left a comment


Thanks for adding this, this will be very useful to have in the IREE codegen.

This also needs llvm/spirv lowering changes and lit tests. We don't need unit tests. See https://mlir.llvm.org/getting_started/TestingGuide/#test-categories

@tyb0807
Contributor Author

tyb0807 commented Jun 16, 2025

Thanks for the review. Actually, I already have all this covered in lit tests. I just wanted to make sure the new builders work as intended. I guess I can just remove the unit tests?

@tyb0807 tyb0807 requested review from krzysz00 and kuhar June 16, 2025 23:54
Member

@matthias-springer matthias-springer left a comment


Is it possible to build an analysis based on memref.assume_alignment instead of adding an alignment attribute to every load/store operation? That's the approach that Triton took (AxisInfo.cpp).

@tyb0807
Contributor Author

tyb0807 commented Jun 17, 2025

Indeed, but I'm not sure we can always infer the alignment of a load/store op solely from the index arithmetic. In cases where this is not possible, we would need a (less automatic) way to specify this constraint, right?

@matthias-springer
Member

Can you show an example where that would not work?

@banach-space
Contributor

Do you ever expect to need something like this (different alignment for different Ops):

func.func @test_load_store_alignment(%memref: memref<4xi32>) {
  %c0 = arith.constant 0 : index
  %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
  vector.store %val, %memref[%c0] { alignment = 32 } : memref<4xi32>, vector<4xi32>
  return
}

I am just wondering, why do we need to "decorate" every Op with this attribute? And what logic is meant to take care of it? Why couldn't the alignment be a parameter that's passed to e.g. a conversion pass?

@matthias-springer
Member

Alignment is a property of the memref SSA value. But we don't encode it in the memref type. We have memref.assume_alignment as a way to attach alignment information to an SSA value. The alignment can then be queried by a dataflow analysis.

There are two alternatives to this approach:

  1. Make the alignment information part of the memref type.
  2. Add attributes to each load/store op. (That's what this PR is doing.)

I'd like to make sure that we have a consistent story for dealing with alignment. Having both memref.assume_alignment and attributes on various ops seems a bit odd...

@kuhar
Member

kuhar commented Jun 17, 2025

> Is it possible to build an analysis based on memref.assume_alignment instead of adding an alignment attribute to every load/store operation?

> Alignment is a property of the memref SSA value. But we don't encode it in the memref type.

I don't think this is the case. You can have a memref of ?xi8 that doesn't have any inherent static alignment and the alignment is really a property at each load/store op. You may end up with a memref of bytes as you lower and merge allocations etc. This is also the case with lower level IRs like llvm or spirv, e.g.: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Memory_Operands.
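
A small C++ sketch of why alignment can end up being a per-access fact rather than a property of the memref value. Under simplifying assumptions, it models what can be proven about the address `base + offset` from the base alignment and a byte offset. The helper names are hypothetical and this is not MLIR's actual analysis.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Largest power of two that divides x (x must be nonzero).
uint64_t largestPow2Divisor(uint64_t x) { return x & (~x + 1); }

// Alignment that can be proven for the address (base + offset).
//  - baseAlign: known power-of-two alignment of the base pointer, in bytes.
//  - offsetBytes: the byte offset if statically known, std::nullopt if dynamic.
uint64_t provableAccessAlignment(uint64_t baseAlign,
                                 std::optional<uint64_t> offsetBytes) {
  assert(baseAlign != 0 && (baseAlign & (baseAlign - 1)) == 0);
  if (!offsetBytes)
    return 1; // Dynamic offset: nothing beyond byte alignment is provable.
  if (*offsetBytes == 0)
    return baseAlign; // The access sits exactly at the base.
  // Otherwise the result is limited by both the base and the offset.
  return std::min(baseAlign, largestPow2Divisor(*offsetBytes));
}
```

With a dynamically indexed byte buffer, the second argument is unknown and the model can only conclude an alignment of 1, even if the producer of the IR knows every access is 16-byte aligned. An explicit attribute on the load/store is one way to carry that knowledge.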

@matthias-springer
Member

Good point, looks like we may want to have it on the load/store ops.

@ftynse
Member

ftynse commented Jun 18, 2025

> You can have a memref of ?xi8 that doesn't have any inherent static alignment and the alignment is really a property at each load/store op.

At that point, maybe you should use !ptr.ptr instead and not a memref. Not necessarily opposing this change, but I don't want to blindly replicate notions from lower-level abstractions like LLVM IR and SPIR-V to a higher-level abstraction.

A stronger argument may be that alignment goes both ways and we can have overaligned and underaligned accesses compared to the natural/preferred alignment of the element type, and those should be reflected somewhere, which is not necessarily a property of the type. Underaligned accesses are more interesting because those may be an optimization hint (aligned accesses are faster) or plainly forbidden by the architecture.
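
The over/under-aligned distinction can be sketched in a few lines of C++. The enum and function names are hypothetical, and "natural alignment" is simply passed in as a number rather than derived from a real data layout.

```cpp
#include <cstdint>

// How an access's stated alignment relates to the element type's
// natural/preferred alignment (both in bytes).
enum class AlignmentKind { Underaligned, Natural, Overaligned };

AlignmentKind classifyAccess(uint64_t accessAlign, uint64_t naturalAlign) {
  if (accessAlign < naturalAlign)
    return AlignmentKind::Underaligned; // May be slow or outright forbidden.
  if (accessAlign > naturalAlign)
    return AlignmentKind::Overaligned; // Extra guarantee a backend may exploit.
  return AlignmentKind::Natural;
}
```

For example, a 4-byte element accessed with a stated alignment of 16 is overaligned, while the same element accessed with alignment 1 is underaligned. Neither fact follows from the element type alone.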

Note that the attribute approach is not precluding a dataflow analysis. We can have an analysis that propagates alignment information to individual operations, e.g., by looking at the structure of subscripts and attributes on previous operations accessing the same value. Attributes can be seen as a way to preserve analysis results.

> Good point, looks like we may want to have it on the load/store ops.

Should we also remove memref.assume_alignment? This operation is rather confusing: nothing precludes one from using it repeatedly on the same value, and it has already been pointed out that it is marked as side-effecting (so DCE doesn't remove it) without actually having side effects.

@@ -1217,6 +1217,11 @@ def LoadOp : MemRef_Op<"load",
be reused in the cache. For details, refer to the
[LLVM load instruction](https://llvm.org/docs/LangRef.html#load-instruction).

An optional `alignment` attribute allows specifying the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
Member


Let's have a verifier that checks for it being a power of 2.
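
A minimal standalone sketch of the check being requested, assuming the attribute is modeled as an optional integer. The function `verifyAlignment` and its messages are hypothetical. In an actual MLIR verifier one would presumably reuse an existing helper such as `llvm::isPowerOf2_32` from `llvm/Support/MathExtras.h` and emit the diagnostic through the op's error reporting.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Returns an error message if the alignment is invalid, or std::nullopt if it
// is acceptable. An absent attribute is fine because it is optional.
std::optional<std::string> verifyAlignment(std::optional<int64_t> alignment) {
  if (!alignment)
    return std::nullopt;
  if (*alignment <= 0)
    return std::string("alignment must be positive");
  const uint64_t value = static_cast<uint64_t>(*alignment);
  // A power of two has exactly one bit set, so clearing the lowest set bit
  // must leave zero.
  if ((value & (value - 1)) != 0)
    return std::string("alignment must be a power of 2");
  return std::nullopt;
}
```

The positivity check mirrors the `IntPositive` constraint already in the patch; the bit test is the additional power-of-2 condition, which would reject values such as 12.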

7 participants