Skip to content

Commit c2d3414

Browse files
committed
[AArch64][SME] Propagate desired ZA states in the MachineSMEABIPass
This patch adds a propagation step to the MachineSMEABIPass that propagates desired ZA states forwards (from predecessors to successors). The aim of this is to pick better ZA states for edge bundles, as when many (or all) blocks in a bundle do not have a preferred ZA state, the ZA state assigned to a bundle can be less than ideal. An important case is nested loops, where only the inner loop has a preferred ZA state. Here we'd like to propagate the ZA state up from the inner loop to the outer loops (to avoid saves/restores in any loop). Change-Id: I39f9c7d7608e2fa070be2fb88351b4d1d0079041
1 parent 2c9e14c commit c2d3414

File tree

2 files changed

+364
-17
lines changed

2 files changed

+364
-17
lines changed

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ struct MachineSMEABI : public MachineFunctionPass {
138138
}
139139

140140
void collectNeededZAStates(MachineFunction &MF, SMEAttrs);
141+
void propagateDesiredStates(MachineFunction &MF);
141142
void pickBundleZAStates(MachineFunction &MF);
142143
void insertStateChanges(MachineFunction &MF, bool IsAgnosticZA);
143144

@@ -202,8 +203,10 @@ struct MachineSMEABI : public MachineFunctionPass {
202203
};
203204

204205
struct BlockInfo {
205-
ZAState FixedEntryState{ZAState::ANY};
206206
SmallVector<InstInfo> Insts;
207+
ZAState FixedEntryState{ZAState::ANY};
208+
ZAState DesiredIncomingState{ZAState::ANY};
209+
ZAState DesiredOutgoingState{ZAState::ANY};
207210
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
208211
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
209212
};
@@ -294,28 +297,74 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
294297

295298
// Reverse vector (as we had to iterate backwards for liveness).
296299
std::reverse(Block.Insts.begin(), Block.Insts.end());
300+
301+
// Record the desired states on entry/exit of this block. These are the
302+
// states that would not incur a state transition.
303+
if (!Block.Insts.empty()) {
304+
Block.DesiredIncomingState = Block.Insts.front().NeededState;
305+
Block.DesiredOutgoingState = Block.Insts.back().NeededState;
306+
}
307+
}
308+
}
309+
310+
void MachineSMEABI::propagateDesiredStates(MachineFunction &MF) {
311+
// This propagates desired states from predecessors to successors. This
312+
// propagates state up loop nests (as an inner loop is a predecessor
313+
// to its outer loops).
314+
SmallVector<MachineBasicBlock *> Worklist;
315+
for (auto [BlockID, BlockInfo] : enumerate(State.Blocks)) {
316+
if (!isLegalEdgeBundleZAState(BlockInfo.DesiredIncomingState))
317+
Worklist.push_back(MF.getBlockNumbered(BlockID));
318+
}
319+
320+
while (!Worklist.empty()) {
321+
MachineBasicBlock *MBB = Worklist.pop_back_val();
322+
auto &BlockInfo = State.Blocks[MBB->getNumber()];
323+
324+
// Pick a legal edge bundle state that matches the majority of predecessors.
325+
int PredStateCounts[ZAState::NUM_ZA_STATE] = {0};
326+
for (MachineBasicBlock *Pred : predecessors(MBB)) {
327+
auto &PredBlockInfo = State.Blocks[Pred->getNumber()];
328+
if (isLegalEdgeBundleZAState(PredBlockInfo.DesiredOutgoingState))
329+
PredStateCounts[PredBlockInfo.DesiredOutgoingState]++;
330+
}
331+
ZAState PropagatedState =
332+
ZAState(max_element(PredStateCounts) - PredStateCounts);
333+
334+
if (PropagatedState != BlockInfo.DesiredIncomingState) {
335+
BlockInfo.DesiredIncomingState = PropagatedState;
336+
// Propagate to outgoing state for blocks that don't care about their
337+
// ZA state.
338+
if (BlockInfo.DesiredOutgoingState == ZAState::ANY)
339+
BlockInfo.DesiredOutgoingState = PropagatedState;
340+
341+
// Push any successors that may need updating to the worklist.
342+
for (MachineBasicBlock *Succ : successors(MBB)) {
343+
auto &SuccBlockInfo = State.Blocks[Succ->getNumber()];
344+
if (!isLegalEdgeBundleZAState(SuccBlockInfo.DesiredIncomingState))
345+
Worklist.push_back(Succ);
346+
}
347+
}
297348
}
298349
}
299350

300351
void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
301352
State.BundleStates.resize(Bundles->getNumBundles());
353+
354+
if (OptLevel != CodeGenOptLevel::None)
355+
propagateDesiredStates(MF);
356+
302357
for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
303358
LLVM_DEBUG(dbgs() << "Picking ZA state for edge bundle: " << I << '\n');
304359

305360
// Attempt to pick a ZA state for this bundle that minimizes state
306361
// transitions. Edges within loops are given a higher weight as we assume
307362
// they will be executed more than once.
308-
// TODO: We should propagate desired incoming/outgoing states through blocks
309-
// that have the "ANY" state first to make better global decisions.
310363
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
311364
for (unsigned BlockID : Bundles->getBlocks(I)) {
312365
LLVM_DEBUG(dbgs() << "- bb." << BlockID);
313366

314367
BlockInfo &Block = State.Blocks[BlockID];
315-
if (Block.Insts.empty()) {
316-
LLVM_DEBUG(dbgs() << " (no state preference)\n");
317-
continue;
318-
}
319368
bool IsLoop = MLI && MLI->getLoopFor(MF.getBlockNumbered(BlockID));
320369
bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
321370
bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
@@ -324,26 +373,28 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
324373
LLVM_DEBUG(dbgs() << " IsLoop");
325374

326375
LLVM_DEBUG(dbgs() << " (EdgeWeight: " << EdgeWeight << ')');
327-
ZAState DesiredIncomingState = Block.Insts.front().NeededState;
328-
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
329-
EdgeStateCounts[DesiredIncomingState] += EdgeWeight;
376+
bool LegalInEdge =
377+
InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
378+
bool LegalOutEgde =
379+
OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
380+
if (LegalInEdge) {
330381
LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
331-
<< getZAStateString(DesiredIncomingState));
382+
<< getZAStateString(Block.DesiredIncomingState));
383+
EdgeStateCounts[Block.DesiredIncomingState] += EdgeWeight;
332384
}
333-
ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
334-
if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
335-
EdgeStateCounts[DesiredOutgoingState] += EdgeWeight;
385+
if (LegalOutEgde) {
336386
LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
337-
<< getZAStateString(DesiredOutgoingState));
387+
<< getZAStateString(Block.DesiredOutgoingState));
388+
EdgeStateCounts[Block.DesiredOutgoingState] += EdgeWeight;
338389
}
390+
if (!LegalInEdge && !LegalOutEgde)
391+
LLVM_DEBUG(dbgs() << " (no state preference)");
339392
LLVM_DEBUG(dbgs() << '\n');
340393
}
341394

342395
ZAState BundleState =
343396
ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
344397

345-
// Force ZA to be active in bundles that don't have a preferred state.
346-
// TODO: Something better here (to avoid extra mode switches).
347398
if (BundleState == ZAState::ANY)
348399
BundleState = ZAState::ACTIVE;
349400

0 commit comments

Comments
 (0)