Skip to content

Commit b5f6ce6

Browse files
authored
[mlir][vector] Propagate alignment from vector to llvm dialects. (#153482)
Allows alignment to be propagated correctly from vector to LLVM dialect operations.
1 parent fdace1c commit b5f6ce6

File tree

4 files changed

+141
-7
lines changed

4 files changed

+141
-7
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1457,7 +1457,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14571457
"operations instead of the alignment of the element type of the "
14581458
"memref. This flag is intended for use with hardware which requires"
14591459
"vector alignment, or in application contexts where it is known all "
1460-
"vector access are naturally aligned. ">,
1460+
"vector access are naturally aligned. If operations have an "
1461+
"alignment attribute set, the alignment attribute takes priority "
1462+
"over this option ">,
14611463
Option<"amx", "enable-amx",
14621464
"bool", /*default=*/"false",
14631465
"Enables the use of AMX dialect while lowering the vector "

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
247247
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
248248

249249
// Resolve alignment.
250+
// Explicit alignment takes priority over use-vector-alignment.
250251
unsigned align = loadOrStoreOp.getAlignment().value_or(0);
251252
if (!align &&
252253
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
@@ -299,8 +300,10 @@ class VectorGatherOpConversion
299300
}
300301

301302
// Resolve alignment.
302-
unsigned align;
303-
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
303+
// Explicit alignment takes priority over use-vector-alignment.
304+
unsigned align = gather.getAlignment().value_or(0);
305+
if (!align &&
306+
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
304307
memRefType, align, useVectorAlignment)))
305308
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
306309

@@ -354,8 +357,10 @@ class VectorScatterOpConversion
354357
}
355358

356359
// Resolve alignment.
357-
unsigned align;
358-
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
360+
// Explicit alignment takes priority over use-vector-alignment.
361+
unsigned align = scatter.getAlignment().value_or(0);
362+
if (!align &&
363+
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
359364
memRefType, align, useVectorAlignment)))
360365
return rewriter.notifyMatchFailure(scatter,
361366
"could not resolve alignment");
@@ -399,8 +404,14 @@ class VectorExpandLoadOpConversion
399404
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
400405
adaptor.getBase(), adaptor.getIndices());
401406

407+
// From:
408+
// https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
409+
// The pointer alignment defaults to 1.
410+
uint64_t alignment = expand.getAlignment().value_or(1);
411+
402412
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
403-
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
413+
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
414+
alignment);
404415
return success();
405416
}
406417
};
@@ -421,8 +432,13 @@ class VectorCompressStoreOpConversion
421432
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
422433
adaptor.getBase(), adaptor.getIndices());
423434

435+
// From:
436+
// https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
437+
// The pointer alignment defaults to 1.
438+
uint64_t alignment = compress.getAlignment().value_or(1);
439+
424440
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
425-
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
441+
compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
426442
return success();
427443
}
428444
};

mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@ func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8
1818

1919
// -----
2020

21+
func.func @load_with_alignment_attribute(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
22+
%0 = vector.load %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<8xf32>
23+
return %0 : vector<8xf32>
24+
}
25+
26+
// ALL-LABEL: func @load_with_alignment_attribute
27+
28+
// VEC-ALIGN: llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
29+
// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
30+
31+
// -----
32+
2133
//===----------------------------------------------------------------------===//
2234
// vector.store
2335
//===----------------------------------------------------------------------===//
@@ -35,6 +47,19 @@ func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
3547

3648
// -----
3749

50+
func.func @store_with_alignment_attribute(%base : memref<200x100xf32>, %i : index, %j : index) {
51+
%val = arith.constant dense<11.0> : vector<4xf32>
52+
vector.store %val, %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32>
53+
return
54+
}
55+
56+
// ALL-LABEL: func @store_with_alignment_attribute
57+
58+
// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
59+
// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
60+
61+
// -----
62+
3863
//===----------------------------------------------------------------------===//
3964
// vector.maskedload
4065
//===----------------------------------------------------------------------===//
@@ -52,6 +77,19 @@ func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: ve
5277

5378
// -----
5479

80+
func.func @masked_load_with_alignment_attribute(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
81+
%c0 = arith.constant 0: index
82+
%0 = vector.maskedload %base[%c0], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
83+
return %0 : vector<16xf32>
84+
}
85+
86+
// ALL-LABEL: func @masked_load_with_alignment_attribute
87+
88+
// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
89+
// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
90+
91+
// -----
92+
5593
//===----------------------------------------------------------------------===//
5694
// vector.maskedstore
5795
//===----------------------------------------------------------------------===//
@@ -69,6 +107,19 @@ func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: v
69107

70108
// -----
71109

110+
func.func @masked_store_with_alignment_attribute(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
111+
%c0 = arith.constant 0: index
112+
vector.maskedstore %base[%c0], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<16xi1>, vector<16xf32>
113+
return
114+
}
115+
116+
// ALL-LABEL: func @masked_store_with_alignment_attribute
117+
118+
// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
119+
// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
120+
121+
// -----
122+
72123
//===----------------------------------------------------------------------===//
73124
// vector.scatter
74125
//===----------------------------------------------------------------------===//
@@ -86,6 +137,19 @@ func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3x
86137

87138
// -----
88139

140+
func.func @scatter_with_alignment_attribute(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
141+
%0 = arith.constant 0: index
142+
vector.scatter %base[%0][%index], %mask, %value {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
143+
return
144+
}
145+
146+
// ALL-LABEL: func @scatter_with_alignment_attribute
147+
148+
// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
149+
// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
150+
151+
// -----
152+
89153
//===----------------------------------------------------------------------===//
90154
// vector.gather
91155
//===----------------------------------------------------------------------===//
@@ -100,3 +164,16 @@ func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi
100164

101165
// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
102166
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
167+
168+
// -----
169+
170+
func.func @gather_with_alignment_attribute(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
171+
%0 = arith.constant 0: index
172+
%1 = vector.gather %base[%0][%index], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
173+
return %1 : vector<3xf32>
174+
}
175+
176+
// ALL-LABEL: func @gather_with_alignment_attribute
177+
178+
// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
179+
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
20422042

20432043
// -----
20442044

2045+
func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
2046+
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
2047+
return %1 : vector<3xf32>
2048+
}
2049+
2050+
// CHECK-LABEL: func @gather_with_alignment
2051+
// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2052+
2053+
// -----
2054+
20452055
//===----------------------------------------------------------------------===//
20462056
// vector.scatter
20472057
//===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
21182128
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
21192129
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
21202130

2131+
// -----
2132+
2133+
func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
2134+
vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
2135+
return
2136+
}
2137+
2138+
// CHECK-LABEL: func @scatter_with_alignment
2139+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
2140+
2141+
21212142
// -----
21222143

21232144
//===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a
21492170

21502171
// -----
21512172

2173+
func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
2174+
%0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
2175+
return %0 : vector<11xindex>
2176+
}
2177+
// CHECK-LABEL: func @expand_load_op_with_alignment
2178+
// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
2179+
2180+
// -----
2181+
21522182
//===----------------------------------------------------------------------===//
21532183
// vector.compressstore
21542184
//===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
21772207

21782208
// -----
21792209

2210+
func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
2211+
vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
2212+
return
2213+
}
2214+
// CHECK-LABEL: func @compress_store_op_with_alignment
2215+
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
2216+
2217+
// -----
2218+
21802219
//===----------------------------------------------------------------------===//
21812220
// vector.splat
21822221
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)