Skip to content

Commit 323b821

Browse files
committed
[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass
This extends the MachineSMEABIPass to handle agnostic ZA functions. Agnostic ZA functions are currently handled like shared ZA functions, but we don't require ZA state to be reloaded before agnostic ZA calls.

Note: This patch does not yet fully handle agnostic ZA functions that can catch exceptions. E.g.:

```
__arm_agnostic("sme_za_state") void try_catch_agnostic_za_callee() {
  try {
    agnostic_za_call();
  } catch(...) {
    noexcept_agnostic_za_call();
  }
}
```

In this case, we won't commit a ZA save before the `agnostic_za_call()`, which would be needed to restore ZA in the catch block. This will be handled in a later patch.

Change-Id: I9cce7b42ec8b64d5442b35231b65dfaf9d149eed
1 parent a2e73ca commit 323b821

File tree

3 files changed

+332
-39
lines changed

3 files changed

+332
-39
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8244,7 +8244,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82448244
if (Subtarget->hasCustomCallingConv())
82458245
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
82468246

8247-
if (Subtarget->useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
8247+
if (Subtarget->useNewSMEABILowering()) {
82488248
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
82498249
SDValue Size;
82508250
if (Attrs.hasZAState()) {
@@ -9060,9 +9060,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90609060
bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
90619061
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
90629062
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9063-
// TODO: Handle agnostic ZA functions.
9064-
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9063+
if (!UseNewSMEABILowering)
9064+
return std::nullopt;
9065+
if (IsAgnosticZAFunction) {
9066+
if (CallAttrs.requiresPreservingAllZAState())
9067+
return AArch64ISD::REQUIRES_ZA_SAVE;
90659068
return std::nullopt;
9069+
}
90669070
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
90679071
return std::nullopt;
90689072
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
@@ -9142,7 +9146,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91429146
};
91439147

91449148
bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
9145-
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9149+
bool RequiresSaveAllZA =
9150+
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
91469151
if (RequiresLazySave) {
91479152
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91489153
MachinePointerInfo MPI =

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 160 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This pass implements the SME ABI requirements for ZA state. This includes
10-
// implementing the lazy ZA state save schemes around calls.
10+
// implementing the lazy (and agnostic) ZA state save schemes around calls.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

@@ -128,7 +128,7 @@ struct MachineSMEABI : public MachineFunctionPass {
128128

129129
void collectNeededZAStates(MachineFunction &MF, SMEAttrs);
130130
void pickBundleZAStates(MachineFunction &MF);
131-
void insertStateChanges(MachineFunction &MF);
131+
void insertStateChanges(MachineFunction &MF, bool IsAgnosticZA);
132132

133133
// Emission routines for private and shared ZA functions (using lazy saves).
134134
void emitNewZAPrologue(MachineBasicBlock &MBB,
@@ -143,11 +143,46 @@ struct MachineSMEABI : public MachineFunctionPass {
143143
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
144144
bool ClearTPIDR2);
145145

146+
// Emission routines for agnostic ZA functions.
147+
void emitSetupFullZASave(MachineBasicBlock &MBB,
148+
MachineBasicBlock::iterator MBBI,
149+
LiveRegs PhysLiveRegs);
150+
void emitFullZASaveRestore(MachineBasicBlock &MBB,
151+
MachineBasicBlock::iterator MBBI,
152+
LiveRegs PhysLiveRegs, bool IsSave);
153+
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
154+
MachineBasicBlock::iterator MBBI,
155+
LiveRegs PhysLiveRegs);
156+
146157
void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
147-
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
158+
ZAState From, ZAState To, LiveRegs PhysLiveRegs,
159+
bool IsAgnosticZA);
160+
161+
// Helpers for switching between lazy/full ZA save/restore routines.
162+
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
163+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
164+
if (IsAgnosticZA)
165+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
166+
return emitSetupLazySave(MBB, MBBI);
167+
}
168+
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
169+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
170+
if (IsAgnosticZA)
171+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
172+
return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
173+
}
174+
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
175+
MachineBasicBlock::iterator MBBI,
176+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
177+
if (IsAgnosticZA)
178+
return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
179+
return emitAllocateLazySaveBuffer(MBB, MBBI);
180+
}
148181

149182
TPIDR2State getTPIDR2Block(MachineFunction &MF);
150183

184+
Register getAgnosticZABufferPtr(MachineFunction &MF);
185+
151186
private:
152187
struct InstInfo {
153188
ZAState NeededState{ZAState::ANY};
@@ -158,6 +193,7 @@ struct MachineSMEABI : public MachineFunctionPass {
158193
struct BlockInfo {
159194
ZAState FixedEntryState{ZAState::ANY};
160195
SmallVector<InstInfo> Insts;
196+
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
161197
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
162198
};
163199

@@ -167,6 +203,9 @@ struct MachineSMEABI : public MachineFunctionPass {
167203
SmallVector<ZAState> BundleStates;
168204
std::optional<TPIDR2State> TPIDR2Block;
169205
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
206+
Register AgnosticZABufferPtr = AArch64::NoRegister;
207+
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
208+
bool HasFullZASaveRestore = false;
170209
} State;
171210

172211
EdgeBundles *Bundles = nullptr;
@@ -175,7 +214,8 @@ struct MachineSMEABI : public MachineFunctionPass {
175214
void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
176215
SMEAttrs SMEFnAttrs) {
177216
const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
178-
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
217+
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
218+
SMEFnAttrs.hasZAState()) &&
179219
"Expected function to have ZA/ZT0 state!");
180220

181221
State.Blocks.resize(MF.getNumBlockIDs());
@@ -209,6 +249,7 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
209249

210250
Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
211251
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
252+
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
212253
for (MachineInstr &MI : reverse(MBB)) {
213254
MachineBasicBlock::iterator MBBI(MI);
214255
LiveUnits.stepBackward(MI);
@@ -219,15 +260,20 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
219260
// block setup.
220261
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
221262
State.AfterSMEProloguePt = MBBI;
263+
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
222264
}
265+
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
223266
auto [NeededState, InsertPt] = getInstNeededZAState(
224-
TRI, MI, /*ZALiveAtReturn=*/SMEFnAttrs.hasSharedZAInterface());
267+
TRI, MI, /*ZALiveAtReturn=*/SMEFnAttrs.hasSharedZAInterface() ||
268+
SMEFnAttrs.hasAgnosticZAInterface());
225269
assert((InsertPt == MBBI ||
226270
InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
227271
"Unexpected state change insertion point!");
228272
// TODO: Do something to avoid state changes where NZCV is live.
229273
if (MBBI == FirstTerminatorInsertPt)
230274
Block.PhysLiveRegsAtExit = PhysLiveRegs;
275+
if (MBBI == FirstNonPhiInsertPt)
276+
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
231277
if (NeededState != ZAState::ANY)
232278
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
233279
}
@@ -294,7 +340,7 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
294340
}
295341
}
296342

297-
void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
343+
void MachineSMEABI::insertStateChanges(MachineFunction &MF, bool IsAgnosticZA) {
298344
for (MachineBasicBlock &MBB : MF) {
299345
BlockInfo &Block = State.Blocks[MBB.getNumber()];
300346
ZAState InState =
@@ -309,7 +355,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
309355
for (auto &Inst : Block.Insts) {
310356
if (CurrentState != Inst.NeededState)
311357
emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
312-
Inst.PhysLiveRegs);
358+
Inst.PhysLiveRegs, IsAgnosticZA);
313359
CurrentState = Inst.NeededState;
314360
}
315361

@@ -318,7 +364,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
318364

319365
if (CurrentState != OutState)
320366
emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
321-
Block.PhysLiveRegsAtExit);
367+
Block.PhysLiveRegsAtExit, IsAgnosticZA);
322368
}
323369
}
324370

@@ -573,10 +619,98 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
573619
emitZeroZA(TII, DL, MBB, MBBI, /*Mask=*/0b11111111);
574620
}
575621

622+
Register MachineSMEABI::getAgnosticZABufferPtr(MachineFunction &MF) {
623+
if (State.AgnosticZABufferPtr != AArch64::NoRegister)
624+
return State.AgnosticZABufferPtr;
625+
if (auto BufferPtr =
626+
MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
627+
BufferPtr != AArch64::NoRegister)
628+
State.AgnosticZABufferPtr = BufferPtr;
629+
else
630+
State.AgnosticZABufferPtr =
631+
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
632+
return State.AgnosticZABufferPtr;
633+
}
634+
635+
void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
636+
MachineBasicBlock::iterator MBBI,
637+
LiveRegs PhysLiveRegs, bool IsSave) {
638+
MachineFunction &MF = *MBB.getParent();
639+
auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
640+
const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
641+
const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
642+
MachineRegisterInfo &MRI = MF.getRegInfo();
643+
644+
State.HasFullZASaveRestore = true;
645+
DebugLoc DL = getDebugLoc(MBB, MBBI);
646+
Register BufferPtr = AArch64::X0;
647+
648+
ScopedPhysRegSave ScopedPhysRegSave(MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
649+
650+
// Copy the buffer pointer into X0.
651+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), BufferPtr)
652+
.addReg(getAgnosticZABufferPtr(MF));
653+
654+
// Call __arm_sme_save/__arm_sme_restore.
655+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::BL))
656+
.addReg(BufferPtr, RegState::Implicit)
657+
.addExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore")
658+
.addRegMask(TRI.getCallPreservedMask(
659+
MF,
660+
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
661+
}
662+
663+
void MachineSMEABI::emitAllocateFullZASaveBuffer(
664+
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
665+
LiveRegs PhysLiveRegs) {
666+
MachineFunction &MF = *MBB.getParent();
667+
MachineFrameInfo &MFI = MF.getFrameInfo();
668+
auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
669+
const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
670+
MachineRegisterInfo &MRI = MF.getRegInfo();
671+
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
672+
673+
// Buffer already allocated in SelectionDAG.
674+
if (AFI->getEarlyAllocSMESaveBuffer())
675+
return;
676+
677+
DebugLoc DL = getDebugLoc(MBB, MBBI);
678+
Register BufferPtr = getAgnosticZABufferPtr(MF);
679+
Register BufferSize = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
680+
681+
ScopedPhysRegSave ScopedPhysRegSave(MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
682+
683+
// Calculate the SME state size.
684+
{
685+
const AArch64RegisterInfo *TRI = Subtarget.getRegisterInfo();
686+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::BL))
687+
.addExternalSymbol("__arm_sme_state_size")
688+
.addReg(AArch64::X0, RegState::ImplicitDefine)
689+
.addRegMask(TRI->getCallPreservedMask(
690+
MF, CallingConv::
691+
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
692+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), BufferSize)
693+
.addReg(AArch64::X0);
694+
}
695+
696+
// Allocate a buffer object of the size given __arm_sme_state_size.
697+
{
698+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::SUBXrx64), AArch64::SP)
699+
.addReg(AArch64::SP)
700+
.addReg(BufferSize)
701+
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
702+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), BufferPtr)
703+
.addReg(AArch64::SP);
704+
705+
// We have just allocated a variable sized object, tell this to PEI.
706+
MFI.CreateVariableSizedObject(Align(16), nullptr);
707+
}
708+
}
709+
576710
void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
577711
MachineBasicBlock::iterator InsertPt,
578712
ZAState From, ZAState To,
579-
LiveRegs PhysLiveRegs) {
713+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
580714

581715
// ZA not used.
582716
if (From == ZAState::ANY || To == ZAState::ANY)
@@ -603,10 +737,11 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
603737
}
604738

605739
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
606-
emitSetupLazySave(MBB, InsertPt);
740+
emitZASave(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
607741
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
608-
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
742+
emitZARestore(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
609743
else if (To == ZAState::OFF) {
744+
assert(!IsAgnosticZA && "Should not turn ZA off in agnostic ZA function");
610745
// If we're exiting from the CALLER_DORMANT state that means this new ZA
611746
// function did not touch ZA (so ZA was never turned on).
612747
if (From != ZAState::CALLER_DORMANT)
@@ -629,7 +764,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
629764

630765
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
631766
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
632-
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
767+
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
768+
!SMEFnAttrs.hasAgnosticZAInterface())
633769
return false;
634770

635771
assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -638,20 +774,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
638774
State = PassState{};
639775
Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
640776

777+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
778+
641779
collectNeededZAStates(MF, SMEFnAttrs);
642780
pickBundleZAStates(MF);
643-
insertStateChanges(MF);
781+
insertStateChanges(MF, /*IsAgnosticZA=*/IsAgnosticZA);
644782

645783
// Allocate save buffer (if needed).
646-
if (State.TPIDR2Block.has_value()) {
784+
if (State.HasFullZASaveRestore || State.TPIDR2Block.has_value()) {
647785
if (State.AfterSMEProloguePt) {
648786
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
649787
// entry block (due to the probing loop).
650-
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
651-
*State.AfterSMEProloguePt);
788+
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
789+
*State.AfterSMEProloguePt,
790+
State.PhysLiveRegsAfterSMEPrologue,
791+
/*IsAgnosticZA=*/IsAgnosticZA);
652792
} else {
653793
MachineBasicBlock &EntryBlock = MF.front();
654-
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
794+
emitAllocateZASaveBuffer(
795+
EntryBlock, EntryBlock.getFirstNonPHI(),
796+
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry,
797+
/*IsAgnosticZA=*/IsAgnosticZA);
655798
}
656799
}
657800

0 commit comments

Comments
 (0)