Skip to content

Commit b3c2183

Browse files
committed
[AArch64][SME] Propagate desired ZA states in the MachineSMEABIPass
This patch adds a propagation step to the MachineSMEABIPass that propagates desired ZA states forwards (from predecessors to successors). The aim of this is to pick better ZA states for edge bundles, as when many (or all) blocks in a bundle do not have a preferred ZA state, the ZA state assigned to a bundle can be less than ideal. An important case is nested loops, where only the inner loop has a preferred ZA state. Here we'd like to propagate the ZA state up from the inner loop to the outer loops (to avoid saves/restores in any loop). Change-Id: I39f9c7d7608e2fa070be2fb88351b4d1d0079041
1 parent 2c9e14c commit b3c2183

File tree

3 files changed

+369
-32
lines changed

3 files changed

+369
-32
lines changed

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ struct MachineSMEABI : public MachineFunctionPass {
138138
}
139139

140140
void collectNeededZAStates(MachineFunction &MF, SMEAttrs);
141+
void propagateDesiredStates(MachineFunction &MF);
141142
void pickBundleZAStates(MachineFunction &MF);
142143
void insertStateChanges(MachineFunction &MF, bool IsAgnosticZA);
143144

@@ -202,8 +203,10 @@ struct MachineSMEABI : public MachineFunctionPass {
202203
};
203204

204205
struct BlockInfo {
205-
ZAState FixedEntryState{ZAState::ANY};
206206
SmallVector<InstInfo> Insts;
207+
ZAState FixedEntryState{ZAState::ANY};
208+
ZAState DesiredIncomingState{ZAState::ANY};
209+
ZAState DesiredOutgoingState{ZAState::ANY};
207210
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
208211
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
209212
};
@@ -294,28 +297,74 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
294297

295298
// Reverse vector (as we had to iterate backwards for liveness).
296299
std::reverse(Block.Insts.begin(), Block.Insts.end());
300+
301+
// Record the desired states on entry/exit of this block. These are the
302+
// states that would not incur a state transition.
303+
if (!Block.Insts.empty()) {
304+
Block.DesiredIncomingState = Block.Insts.front().NeededState;
305+
Block.DesiredOutgoingState = Block.Insts.back().NeededState;
306+
}
307+
}
308+
}
309+
310+
// Worklist-driven pass over State.Blocks: each block whose desired incoming
// state is not a legal edge bundle state adopts the legal state most common
// among its predecessors' desired outgoing states.
void MachineSMEABI::propagateDesiredStates(MachineFunction &MF) {
311+
// This propagates desired states from predecessors to successors. This
312+
// propagates state up loop nests (as an inner loop is a predecessor
313+
// to its outer loops).
314+
SmallVector<MachineBasicBlock *> Worklist;
315+
// Seed the worklist with every block that lacks a legal incoming state.
for (auto [BlockID, BlockInfo] : enumerate(State.Blocks)) {
316+
if (!isLegalEdgeBundleZAState(BlockInfo.DesiredIncomingState))
317+
Worklist.push_back(MF.getBlockNumbered(BlockID));
318+
}
319+
320+
while (!Worklist.empty()) {
321+
MachineBasicBlock *MBB = Worklist.pop_back_val();
322+
auto &BlockInfo = State.Blocks[MBB->getNumber()];
323+
324+
// Pick a legal edge bundle state that matches the majority of predecessors.
325+
int PredStateCounts[ZAState::NUM_ZA_STATE] = {0};
326+
for (MachineBasicBlock *Pred : predecessors(MBB)) {
327+
auto &PredBlockInfo = State.Blocks[Pred->getNumber()];
328+
if (isLegalEdgeBundleZAState(PredBlockInfo.DesiredOutgoingState))
329+
PredStateCounts[PredBlockInfo.DesiredOutgoingState]++;
330+
}
// The winning state is the index of the largest count. When no predecessor
// contributes a legal state all counts are zero, so this presumably resolves
// to the state with value 0 (likely ZAState::ANY -- TODO confirm against the
// enum definition and max_element's tie-breaking).
331+
ZAState PropagatedState =
332+
ZAState(max_element(PredStateCounts) - PredStateCounts);
333+
334+
// Only update (and re-queue successors) when the state actually changes.
if (PropagatedState != BlockInfo.DesiredIncomingState) {
335+
BlockInfo.DesiredIncomingState = PropagatedState;
336+
// Propagate to outgoing state for blocks that don't care about their
337+
// ZA state.
338+
if (BlockInfo.DesiredOutgoingState == ZAState::ANY)
339+
BlockInfo.DesiredOutgoingState = PropagatedState;
340+
341+
// Push any successors that may need updating to the worklist.
342+
for (MachineBasicBlock *Succ : successors(MBB)) {
343+
auto &SuccBlockInfo = State.Blocks[Succ->getNumber()];
344+
if (!isLegalEdgeBundleZAState(SuccBlockInfo.DesiredIncomingState))
345+
Worklist.push_back(Succ);
346+
}
347+
}
297348
}
298349
}
299350

300351
void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
301352
State.BundleStates.resize(Bundles->getNumBundles());
353+
354+
if (OptLevel != CodeGenOptLevel::None)
355+
propagateDesiredStates(MF);
356+
302357
for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
303358
LLVM_DEBUG(dbgs() << "Picking ZA state for edge bundle: " << I << '\n');
304359

305360
// Attempt to pick a ZA state for this bundle that minimizes state
306361
// transitions. Edges within loops are given a higher weight as we assume
307362
// they will be executed more than once.
308-
// TODO: We should propagate desired incoming/outgoing states through blocks
309-
// that have the "ANY" state first to make better global decisions.
310363
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
311364
for (unsigned BlockID : Bundles->getBlocks(I)) {
312365
LLVM_DEBUG(dbgs() << "- bb." << BlockID);
313366

314367
BlockInfo &Block = State.Blocks[BlockID];
315-
if (Block.Insts.empty()) {
316-
LLVM_DEBUG(dbgs() << " (no state preference)\n");
317-
continue;
318-
}
319368
bool IsLoop = MLI && MLI->getLoopFor(MF.getBlockNumbered(BlockID));
320369
bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
321370
bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
@@ -324,26 +373,28 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
324373
LLVM_DEBUG(dbgs() << " IsLoop");
325374

326375
LLVM_DEBUG(dbgs() << " (EdgeWeight: " << EdgeWeight << ')');
327-
ZAState DesiredIncomingState = Block.Insts.front().NeededState;
328-
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
329-
EdgeStateCounts[DesiredIncomingState] += EdgeWeight;
376+
bool LegalInEdge =
377+
InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
378+
bool LegalOutEgde =
379+
OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
380+
if (LegalInEdge) {
330381
LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
331-
<< getZAStateString(DesiredIncomingState));
382+
<< getZAStateString(Block.DesiredIncomingState));
383+
EdgeStateCounts[Block.DesiredIncomingState] += EdgeWeight;
332384
}
333-
ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
334-
if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
335-
EdgeStateCounts[DesiredOutgoingState] += EdgeWeight;
385+
if (LegalOutEgde) {
336386
LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
337-
<< getZAStateString(DesiredOutgoingState));
387+
<< getZAStateString(Block.DesiredOutgoingState));
388+
EdgeStateCounts[Block.DesiredOutgoingState] += EdgeWeight;
338389
}
390+
if (!LegalInEdge && !LegalOutEgde)
391+
LLVM_DEBUG(dbgs() << " (no state preference)");
339392
LLVM_DEBUG(dbgs() << '\n');
340393
}
341394

342395
ZAState BundleState =
343396
ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
344397

345-
// Force ZA to be active in bundles that don't have a preferred state.
346-
// TODO: Something better here (to avoid extra mode switches).
347398
if (BundleState == ZAState::ANY)
348399
BundleState = ZAState::ACTIVE;
349400

llvm/test/CodeGen/AArch64/sme-za-exceptions.ll

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,27 +62,17 @@ define void @za_with_raii(i1 %fail) "aarch64_inout_za" personality ptr @__gxx_pe
6262
; CHECK-NEXT: ldr x1, [x1, :got_lo12:typeinfo_for_char_const_ptr]
6363
; CHECK-NEXT: bl __cxa_throw
6464
; CHECK-NEXT: .Ltmp1:
65-
; CHECK-NEXT: mov x8, x0
66-
; CHECK-NEXT: smstart za
67-
; CHECK-NEXT: mrs x9, TPIDR2_EL0
68-
; CHECK-NEXT: sub x0, x29, #16
69-
; CHECK-NEXT: cbnz x9, .LBB0_4
70-
; CHECK-NEXT: // %bb.3: // %throw_exception
71-
; CHECK-NEXT: bl __arm_tpidr2_restore
72-
; CHECK-NEXT: .LBB0_4: // %throw_exception
73-
; CHECK-NEXT: msr TPIDR2_EL0, xzr
74-
; CHECK-NEXT: // kill: def $x0 killed $x8
75-
; CHECK-NEXT: // %bb.5: // %throw_fail
76-
; CHECK-NEXT: .LBB0_6: // %unwind_dtors
65+
; CHECK-NEXT: // %bb.3: // %throw_fail
66+
; CHECK-NEXT: .LBB0_4: // %unwind_dtors
7767
; CHECK-NEXT: .Ltmp2:
7868
; CHECK-NEXT: mov x19, x0
7969
; CHECK-NEXT: smstart za
8070
; CHECK-NEXT: mrs x8, TPIDR2_EL0
8171
; CHECK-NEXT: sub x0, x29, #16
82-
; CHECK-NEXT: cbnz x8, .LBB0_8
83-
; CHECK-NEXT: // %bb.7: // %unwind_dtors
72+
; CHECK-NEXT: cbnz x8, .LBB0_6
73+
; CHECK-NEXT: // %bb.5: // %unwind_dtors
8474
; CHECK-NEXT: bl __arm_tpidr2_restore
85-
; CHECK-NEXT: .LBB0_8: // %unwind_dtors
75+
; CHECK-NEXT: .LBB0_6: // %unwind_dtors
8676
; CHECK-NEXT: msr TPIDR2_EL0, xzr
8777
; CHECK-NEXT: bl shared_za_call
8878
; CHECK-NEXT: sub x8, x29, #16

0 commit comments

Comments
 (0)