
Commit 75b3c3d

[ARM] Disable UpperBound loop unrolling for MVE tail predicated loops. (llvm#69709)
For MVE tail predicated loops, better code can be generated by keeping the loop whole than by unrolling it to an upper bound, which requires the expansion of active lane masks that can be difficult to generate good code for. This patch disables UpperBound unrolling when we find an active_lane_mask in the loop.
Parent commit: 03ec84a
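To make the check concrete, the sketch below restates it as a standalone helper: scan the loop header for a call to llvm.get.active.lane.mask and, if one is found, keep the loop whole instead of upper-bound unrolling it. The helper name headerHasActiveLaneMask is illustrative only; the actual patch inlines this logic into ARMTTIImpl::getUnrollingPreferences and additionally gates it on ST->hasMVEIntegerOps(), as the diff below shows.

// Sketch only: a standalone form of the header scan the patch adds; the
// helper name is hypothetical.
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/IntrinsicInst.h"

using namespace llvm;

static bool headerHasActiveLaneMask(const Loop &L) {
  // For a tail-folded vector loop the vectorizer places the
  // llvm.get.active.lane.mask call in the loop header, so only the header
  // is scanned here, matching the patch.
  return any_of(*L.getHeader(), [](const Instruction &I) {
    const auto *II = dyn_cast<IntrinsicInst>(&I);
    return II && II->getIntrinsicID() == Intrinsic::get_active_lane_mask;
  });
}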

File tree: 2 files changed, +88 −3 lines

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 9 additions & 3 deletions
@@ -2430,9 +2430,15 @@ ARMTTIImpl::getPreferredTailFoldingStyle(bool IVUpdateMayOverflow) const {
 void ARMTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                          TTI::UnrollingPreferences &UP,
                                          OptimizationRemarkEmitter *ORE) {
-  // Enable Upper bound unrolling universally, not dependant upon the conditions
-  // below.
-  UP.UpperBound = true;
+  // Enable Upper bound unrolling universally, providing that we do not see an
+  // active lane mask, which will be better kept as a loop to become tail
+  // predicated than to be conditionally unrolled.
+  UP.UpperBound =
+      !ST->hasMVEIntegerOps() || !any_of(*L->getHeader(), [](Instruction &I) {
+        return isa<IntrinsicInst>(I) &&
+               cast<IntrinsicInst>(I).getIntrinsicID() ==
+                   Intrinsic::get_active_lane_mask;
+      });
 
   // Only currently enable these preferences for M-Class cores.
   if (!ST->isMClass())
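For context on where this preference is consumed: the generic loop unroller asks the target for its unrolling preferences and only attempts a conditional upper-bound unroll when UP.UpperBound is set. The sketch below shows such a query through the public TargetTransformInfo hook; the helper name shouldTryUpperBoundUnroll is hypothetical, and LLVM's LoopUnroll pass seeds the preferences with its own defaults before calling the hook.

// Sketch only: consulting the hook changed above; not the real unroller.
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"

using namespace llvm;

static bool shouldTryUpperBoundUnroll(Loop *L, ScalarEvolution &SE,
                                      const TargetTransformInfo &TTI,
                                      OptimizationRemarkEmitter *ORE) {
  TargetTransformInfo::UnrollingPreferences UP;
  UP.UpperBound = false; // The real pass fills in defaults for every field.
  // With this patch, ARM leaves UpperBound off for MVE loops whose header
  // contains llvm.get.active.lane.mask, so they stay whole and can become
  // tail predicated.
  TTI.getUnrollingPreferences(L, SE, UP, ORE);
  return UP.UpperBound;
}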
Lines changed: 79 additions & 0 deletions (new test file)
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=loop-unroll -S -mtriple thumbv8.1m.main-none-eabi -mattr=+mve %s | FileCheck %s
+
+; The vector loop here is better kept as a loop than conditionally unrolled,
+; letting it transform into a tail predicted loop.
+
+define void @unroll_upper(ptr noundef %pSrc, ptr nocapture noundef writeonly %pDst, i32 noundef %blockSize) {
+; CHECK-LABEL: @unroll_upper(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP_NOT23:%.*]] = icmp ult i32 [[BLOCKSIZE:%.*]], 16
+; CHECK-NEXT: [[AND:%.*]] = and i32 [[BLOCKSIZE]], 15
+; CHECK-NEXT: [[CMP6_NOT28:%.*]] = icmp eq i32 [[AND]], 0
+; CHECK-NEXT: br i1 [[CMP6_NOT28]], label [[WHILE_END12:%.*]], label [[VECTOR_MEMCHECK:%.*]]
+; CHECK: vector.memcheck:
+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[PDST:%.*]], i32 [[AND]]
+; CHECK-NEXT: [[TMP0:%.*]] = shl nuw nsw i32 [[AND]], 1
+; CHECK-NEXT: [[SCEVGEP32:%.*]] = getelementptr i8, ptr [[PSRC:%.*]], i32 [[TMP0]]
+; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[PDST]], [[SCEVGEP32]]
+; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[PSRC]], [[SCEVGEP]]
+; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
+; CHECK-NEXT: [[N_RND_UP:%.*]] = add nuw nsw i32 [[AND]], 7
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[N_RND_UP]], 24
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_MEMCHECK]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[PDST]], i32 [[INDEX]]
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[INDEX]], 1
+; CHECK-NEXT: [[NEXT_GEP37:%.*]] = getelementptr i8, ptr [[PSRC]], i32 [[TMP1]]
+; CHECK-NEXT: [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[AND]])
+; CHECK-NEXT: [[WIDE_MASKED_LOAD:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr [[NEXT_GEP37]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <8 x i16> [[WIDE_MASKED_LOAD]], <i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8>
+; CHECK-NEXT: [[TMP3:%.*]] = trunc <8 x i16> [[TMP2]] to <8 x i8>
+; CHECK-NEXT: call void @llvm.masked.store.v8i8.p0(<8 x i8> [[TMP3]], ptr [[NEXT_GEP]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT: [[INDEX_NEXT]] = add i32 [[INDEX]], 8
+; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP4]], label [[WHILE_END12_LOOPEXIT:%.*]], label [[VECTOR_BODY]]
+; CHECK: while.end12.loopexit:
+; CHECK-NEXT: br label [[WHILE_END12]]
+; CHECK: while.end12:
+; CHECK-NEXT: ret void
+;
+entry:
+  %cmp.not23 = icmp ult i32 %blockSize, 16
+  %and = and i32 %blockSize, 15
+  %cmp6.not28 = icmp eq i32 %and, 0
+  br i1 %cmp6.not28, label %while.end12, label %vector.memcheck
+
+vector.memcheck: ; preds = %entry
+  %scevgep = getelementptr i8, ptr %pDst, i32 %and
+  %0 = shl nuw nsw i32 %and, 1
+  %scevgep32 = getelementptr i8, ptr %pSrc, i32 %0
+  %bound0 = icmp ult ptr %pDst, %scevgep32
+  %bound1 = icmp ult ptr %pSrc, %scevgep
+  %found.conflict = and i1 %bound0, %bound1
+  %n.rnd.up = add nuw nsw i32 %and, 7
+  %n.vec = and i32 %n.rnd.up, 24
+  br label %vector.body
+
+vector.body: ; preds = %vector.body, %vector.memcheck
+  %index = phi i32 [ 0, %vector.memcheck ], [ %index.next, %vector.body ]
+  %next.gep = getelementptr i8, ptr %pDst, i32 %index
+  %1 = shl i32 %index, 1
+  %next.gep37 = getelementptr i8, ptr %pSrc, i32 %1
+  %active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 %index, i32 %and)
+  %wide.masked.load = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr %next.gep37, i32 2, <8 x i1> %active.lane.mask, <8 x i16> poison)
+  %2 = lshr <8 x i16> %wide.masked.load, <i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8>
+  %3 = trunc <8 x i16> %2 to <8 x i8>
+  call void @llvm.masked.store.v8i8.p0(<8 x i8> %3, ptr %next.gep, i32 1, <8 x i1> %active.lane.mask)
+  %index.next = add i32 %index, 8
+  %4 = icmp eq i32 %index.next, %n.vec
+  br i1 %4, label %while.end12, label %vector.body
+
+while.end12: ; preds = %vector.body, %entry
+  ret void
+}
+
+declare <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32, i32)
+declare <8 x i16> @llvm.masked.load.v8i16.p0(ptr nocapture, i32 immarg, <8 x i1>, <8 x i16>)
+declare void @llvm.masked.store.v8i8.p0(<8 x i8>, ptr nocapture, i32 immarg, <8 x i1>)
