Skip to content

Commit 72f1927

Browse files
committed
Always align SVE vectors to 16 bytes and predicates to 2 bytes
1 parent 2ba4626 commit 72f1927

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -82,15 +82,13 @@ struct RelaxScalableVectorAllocaAlignment
8282

8383
LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
8484
PatternRewriter &rewriter) const override {
85-
auto elementType = allocaOp.getType().getElementType();
86-
auto vectorType = llvm::dyn_cast<VectorType>(elementType);
85+
auto memrefElementType = allocaOp.getType().getElementType();
86+
auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
8787
if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
8888
return failure();
8989

90-
unsigned elementByteSize =
91-
vectorType.getElementType().getIntOrFloatBitWidth() / 8;
92-
93-
unsigned aligment = std::max(1u, elementByteSize);
90+
// Set alignment based on the defaults for SVE vectors and predicates.
91+
unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
9492
allocaOp.setAlignment(aligment);
9593

9694
return success();

mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
88
// CHECK-SAME: %[[MASK:.*]]: vector<[1]xi1>)
99
func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
10-
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
10+
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
1111
%alloca = memref.alloca() : memref<vector<[1]xi1>>
1212
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
1313
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -24,7 +24,7 @@ func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vecto
2424
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
2525
// CHECK-SAME: %[[MASK:.*]]: vector<[2]xi1>)
2626
func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
27-
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
27+
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
2828
%alloca = memref.alloca() : memref<vector<[2]xi1>>
2929
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
3030
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -41,7 +41,7 @@ func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vecto
4141
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
4242
// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>)
4343
func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
44-
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
44+
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
4545
%alloca = memref.alloca() : memref<vector<[4]xi1>>
4646
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
4747
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -58,7 +58,7 @@ func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vecto
5858
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
5959
// CHECK-SAME: %[[MASK:.*]]: vector<[8]xi1>)
6060
func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
61-
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
61+
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
6262
%alloca = memref.alloca() : memref<vector<[8]xi1>>
6363
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
6464
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
@@ -77,7 +77,7 @@ func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vecto
7777
func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
7878
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
7979
%c0 = arith.constant 0 : index
80-
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 1 : i64} : memref<vector<3x[16]xi1>>
80+
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
8181
%alloca = memref.alloca() : memref<vector<3x[8]xi1>>
8282
// CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
8383
// CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
@@ -95,47 +95,47 @@ func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]
9595

9696
// CHECK-LABEL: @set_sve_alloca_alignment
9797
func.func @set_sve_alloca_alignment() {
98-
// CHECK-COUNT-6: alignment = 1
98+
// CHECK-COUNT-6: alignment = 16
9999
%a1 = memref.alloca() : memref<vector<[32]xi8>>
100100
%a2 = memref.alloca() : memref<vector<[16]xi8>>
101101
%a3 = memref.alloca() : memref<vector<[8]xi8>>
102102
%a4 = memref.alloca() : memref<vector<[4]xi8>>
103103
%a5 = memref.alloca() : memref<vector<[2]xi8>>
104104
%a6 = memref.alloca() : memref<vector<[1]xi8>>
105105

106-
// CHECK-COUNT-6: alignment = 2
106+
// CHECK-COUNT-6: alignment = 16
107107
%b1 = memref.alloca() : memref<vector<[32]xi16>>
108108
%b2 = memref.alloca() : memref<vector<[16]xi16>>
109109
%b3 = memref.alloca() : memref<vector<[8]xi16>>
110110
%b4 = memref.alloca() : memref<vector<[4]xi16>>
111111
%b5 = memref.alloca() : memref<vector<[2]xi16>>
112112
%b6 = memref.alloca() : memref<vector<[1]xi16>>
113113

114-
// CHECK-COUNT-6: alignment = 4
114+
// CHECK-COUNT-6: alignment = 16
115115
%c1 = memref.alloca() : memref<vector<[32]xi32>>
116116
%c2 = memref.alloca() : memref<vector<[16]xi32>>
117117
%c3 = memref.alloca() : memref<vector<[8]xi32>>
118118
%c4 = memref.alloca() : memref<vector<[4]xi32>>
119119
%c5 = memref.alloca() : memref<vector<[2]xi32>>
120120
%c6 = memref.alloca() : memref<vector<[1]xi32>>
121121

122-
// CHECK-COUNT-6: alignment = 8
122+
// CHECK-COUNT-6: alignment = 16
123123
%d1 = memref.alloca() : memref<vector<[32]xi64>>
124124
%d2 = memref.alloca() : memref<vector<[16]xi64>>
125125
%d3 = memref.alloca() : memref<vector<[8]xi64>>
126126
%d4 = memref.alloca() : memref<vector<[4]xi64>>
127127
%d5 = memref.alloca() : memref<vector<[2]xi64>>
128128
%d6 = memref.alloca() : memref<vector<[1]xi64>>
129129

130-
// CHECK-COUNT-6: alignment = 4
130+
// CHECK-COUNT-6: alignment = 16
131131
%e1 = memref.alloca() : memref<vector<[32]xf32>>
132132
%e2 = memref.alloca() : memref<vector<[16]xf32>>
133133
%e3 = memref.alloca() : memref<vector<[8]xf32>>
134134
%e4 = memref.alloca() : memref<vector<[4]xf32>>
135135
%e5 = memref.alloca() : memref<vector<[2]xf32>>
136136
%e6 = memref.alloca() : memref<vector<[1]xf32>>
137137

138-
// CHECK-COUNT-6: alignment = 8
138+
// CHECK-COUNT-6: alignment = 16
139139
%f1 = memref.alloca() : memref<vector<[32]xf64>>
140140
%f2 = memref.alloca() : memref<vector<[16]xf64>>
141141
%f3 = memref.alloca() : memref<vector<[8]xf64>>

0 commit comments

Comments
 (0)