Skip to content

Commit af4a764

Browse files
committed
[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass
This extends the MachineSMEABIPass to handle agnostic ZA functions. Agnostic ZA functions are currently handled like shared ZA functions, except that we don't require ZA state to be reloaded before agnostic ZA calls. Note: This patch does not yet fully handle agnostic ZA functions that can catch exceptions. E.g.: ``` __arm_agnostic("sme_za_state") void try_catch_agnostic_za_callee() { try { agnostic_za_call(); } catch(...) { noexcept_agnostic_za_call(); } } ``` In this case, we won't commit a ZA save before the `agnostic_za_call()`, even though that save would be needed to restore ZA in the catch block. This will be handled in a later patch. Change-Id: I9cce7b42ec8b64d5442b35231b65dfaf9d149eed
1 parent 4250bec commit af4a764

File tree

3 files changed

+332
-39
lines changed

3 files changed

+332
-39
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8154,7 +8154,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
81548154
if (Subtarget->hasCustomCallingConv())
81558155
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
81568156

8157-
if (Subtarget->useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
8157+
if (Subtarget->useNewSMEABILowering()) {
81588158
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
81598159
SDValue Size;
81608160
if (Attrs.hasZAState()) {
@@ -8965,9 +8965,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
89658965
bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
89668966
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
89678967
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
8968-
// TODO: Handle agnostic ZA functions.
8969-
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
8968+
if (!UseNewSMEABILowering)
8969+
return std::nullopt;
8970+
if (IsAgnosticZAFunction) {
8971+
if (CallAttrs.requiresPreservingAllZAState())
8972+
return AArch64ISD::REQUIRES_ZA_SAVE;
89708973
return std::nullopt;
8974+
}
89718975
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
89728976
return std::nullopt;
89738977
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
@@ -9047,7 +9051,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90479051
};
90489052

90499053
bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
9050-
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9054+
bool RequiresSaveAllZA =
9055+
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
90519056
if (RequiresLazySave) {
90529057
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
90539058
MachinePointerInfo MPI =

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 160 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This pass implements the SME ABI requirements for ZA state. This includes
10-
// implementing the lazy ZA state save schemes around calls.
10+
// implementing the lazy (and agnostic) ZA state save schemes around calls.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

@@ -128,7 +128,7 @@ struct MachineSMEABI : public MachineFunctionPass {
128128

129129
void collectNeededZAStates(MachineFunction &MF, SMEAttrs);
130130
void pickBundleZAStates(MachineFunction &MF);
131-
void insertStateChanges(MachineFunction &MF);
131+
void insertStateChanges(MachineFunction &MF, bool IsAgnosticZA);
132132

133133
// Emission routines for private and shared ZA functions (using lazy saves).
134134
void emitNewZAPrologue(MachineBasicBlock &MBB,
@@ -143,11 +143,46 @@ struct MachineSMEABI : public MachineFunctionPass {
143143
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
144144
bool ClearTPIDR2);
145145

146+
// Emission routines for agnostic ZA functions.
147+
void emitSetupFullZASave(MachineBasicBlock &MBB,
148+
MachineBasicBlock::iterator MBBI,
149+
LiveRegs PhysLiveRegs);
150+
void emitFullZASaveRestore(MachineBasicBlock &MBB,
151+
MachineBasicBlock::iterator MBBI,
152+
LiveRegs PhysLiveRegs, bool IsSave);
153+
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
154+
MachineBasicBlock::iterator MBBI,
155+
LiveRegs PhysLiveRegs);
156+
146157
void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
147-
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
158+
ZAState From, ZAState To, LiveRegs PhysLiveRegs,
159+
bool IsAgnosticZA);
160+
161+
// Helpers for switching between lazy/full ZA save/restore routines.
162+
// Emits a ZA save before MBBI in MBB, picking the scheme from IsAgnosticZA:
//  - agnostic ZA: forwards to emitFullZASaveRestore() with IsSave=true,
//    passing PhysLiveRegs along;
//  - otherwise:   forwards to emitSetupLazySave() (PhysLiveRegs is not used
//    on this path).
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
163+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
164+
if (IsAgnosticZA)
165+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
166+
return emitSetupLazySave(MBB, MBBI);
167+
}
168+
// Emits a ZA restore before MBBI in MBB, picking the scheme from IsAgnosticZA:
//  - agnostic ZA: forwards to emitFullZASaveRestore() with IsSave=false;
//  - otherwise:   forwards to emitRestoreLazySave().
// PhysLiveRegs is passed through to whichever routine is selected.
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
169+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
170+
if (IsAgnosticZA)
171+
return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
172+
return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
173+
}
174+
// Emits the allocation of the ZA save buffer before MBBI in MBB, picking the
// scheme from IsAgnosticZA:
//  - agnostic ZA: forwards to emitAllocateFullZASaveBuffer(), passing
//    PhysLiveRegs along;
//  - otherwise:   forwards to emitAllocateLazySaveBuffer() (PhysLiveRegs is
//    not used on this path).
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
175+
MachineBasicBlock::iterator MBBI,
176+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
177+
if (IsAgnosticZA)
178+
return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
179+
return emitAllocateLazySaveBuffer(MBB, MBBI);
180+
}
148181

149182
TPIDR2State getTPIDR2Block(MachineFunction &MF);
150183

184+
Register getAgnosticZABufferPtr(MachineFunction &MF);
185+
151186
private:
152187
struct InstInfo {
153188
ZAState NeededState{ZAState::ANY};
@@ -158,6 +193,7 @@ struct MachineSMEABI : public MachineFunctionPass {
158193
struct BlockInfo {
159194
ZAState FixedEntryState{ZAState::ANY};
160195
SmallVector<InstInfo> Insts;
196+
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
161197
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
162198
};
163199

@@ -167,6 +203,9 @@ struct MachineSMEABI : public MachineFunctionPass {
167203
SmallVector<ZAState> BundleStates;
168204
std::optional<TPIDR2State> TPIDR2Block;
169205
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
206+
Register AgnosticZABufferPtr = AArch64::NoRegister;
207+
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
208+
bool HasFullZASaveRestore = false;
170209
} State;
171210

172211
EdgeBundles *Bundles = nullptr;
@@ -175,7 +214,8 @@ struct MachineSMEABI : public MachineFunctionPass {
175214
void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
176215
SMEAttrs SMEFnAttrs) {
177216
const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
178-
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
217+
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
218+
SMEFnAttrs.hasZAState()) &&
179219
"Expected function to have ZA/ZT0 state!");
180220

181221
State.Blocks.resize(MF.getNumBlockIDs());
@@ -209,6 +249,7 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
209249

210250
Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
211251
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
252+
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
212253
for (MachineInstr &MI : reverse(MBB)) {
213254
MachineBasicBlock::iterator MBBI(MI);
214255
LiveUnits.stepBackward(MI);
@@ -219,15 +260,20 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
219260
// block setup.
220261
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
221262
State.AfterSMEProloguePt = MBBI;
263+
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
222264
}
265+
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
223266
auto [NeededState, InsertPt] = getInstNeededZAState(
224-
TRI, MI, /*ZALiveAtReturn=*/SMEFnAttrs.hasSharedZAInterface());
267+
TRI, MI, /*ZALiveAtReturn=*/SMEFnAttrs.hasSharedZAInterface() ||
268+
SMEFnAttrs.hasAgnosticZAInterface());
225269
assert((InsertPt == MBBI ||
226270
InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
227271
"Unexpected state change insertion point!");
228272
// TODO: Do something to avoid state changes where NZCV is live.
229273
if (MBBI == FirstTerminatorInsertPt)
230274
Block.PhysLiveRegsAtExit = PhysLiveRegs;
275+
if (MBBI == FirstNonPhiInsertPt)
276+
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
231277
if (NeededState != ZAState::ANY)
232278
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
233279
}
@@ -294,7 +340,7 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
294340
}
295341
}
296342

297-
void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
343+
void MachineSMEABI::insertStateChanges(MachineFunction &MF, bool IsAgnosticZA) {
298344
for (MachineBasicBlock &MBB : MF) {
299345
BlockInfo &Block = State.Blocks[MBB.getNumber()];
300346
ZAState InState =
@@ -309,7 +355,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
309355
for (auto &Inst : Block.Insts) {
310356
if (CurrentState != Inst.NeededState)
311357
emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
312-
Inst.PhysLiveRegs);
358+
Inst.PhysLiveRegs, IsAgnosticZA);
313359
CurrentState = Inst.NeededState;
314360
}
315361

@@ -318,7 +364,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
318364

319365
if (CurrentState != OutState)
320366
emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
321-
Block.PhysLiveRegsAtExit);
367+
Block.PhysLiveRegsAtExit, IsAgnosticZA);
322368
}
323369
}
324370

@@ -571,10 +617,98 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
571617
emitZeroZA(TII, DL, MBB, MBBI, /*Mask=*/0b11111111);
572618
}
573619

620+
// Returns the register holding the pointer to the agnostic-ZA save buffer.
//
// The result is cached in State.AgnosticZABufferPtr, so every call within one
// run of the pass yields the same register. On the first call the register is
// chosen as follows:
//  - if AArch64FunctionInfo reports an early-allocated SME save buffer
//    (getEarlyAllocSMESaveBuffer() != NoRegister), that register is reused;
//  - otherwise a fresh GPR64 virtual register is created. This function does
//    not define that register; emitAllocateFullZASaveBuffer() emits a COPY
//    from SP into it.
Register MachineSMEABI::getAgnosticZABufferPtr(MachineFunction &MF) {
621+
if (State.AgnosticZABufferPtr != AArch64::NoRegister)
622+
return State.AgnosticZABufferPtr;
623+
if (auto BufferPtr =
624+
MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
625+
BufferPtr != AArch64::NoRegister)
626+
State.AgnosticZABufferPtr = BufferPtr;
627+
else
628+
State.AgnosticZABufferPtr =
629+
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
630+
return State.AgnosticZABufferPtr;
631+
}
632+
633+
// Emits a call to the full ZA save or restore support routine before MBBI.
//
// MBB/MBBI    - insertion point for the emitted instructions.
// PhysLiveRegs - live physical registers at the insertion point; handed to
//               ScopedPhysRegSave.
// IsSave      - true calls "__arm_sme_save", false calls "__arm_sme_restore".
//
// Side effect: sets State.HasFullZASaveRestore, which runOnMachineFunction()
// checks when deciding whether a save buffer has to be allocated.
void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
634+
MachineBasicBlock::iterator MBBI,
635+
LiveRegs PhysLiveRegs, bool IsSave) {
636+
MachineFunction &MF = *MBB.getParent();
637+
auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
638+
const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
639+
const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
640+
MachineRegisterInfo &MRI = MF.getRegInfo();
641+
642+
State.HasFullZASaveRestore = true;
643+
DebugLoc DL = getDebugLoc(MBB, MBBI);
// The buffer pointer is passed to the support routine in X0.
644+
Register BufferPtr = AArch64::X0;
645+
// NOTE(review): ScopedPhysRegSave presumably preserves the registers named in
// PhysLiveRegs around the code emitted while it is in scope -- confirm against
// its definition (not visible here).
646+
ScopedPhysRegSave ScopedPhysRegSave(MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
647+
648+
// Copy the buffer pointer into X0.
649+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), BufferPtr)
650+
.addReg(getAgnosticZABufferPtr(MF));
651+
652+
// Call __arm_sme_save/__arm_sme_restore.
// X0 is added as an implicit use of the call, and the register mask is the
// call-preserved mask of the SME ABI support-routine calling convention
// (PreserveMost_From_X1).
653+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::BL))
654+
.addReg(BufferPtr, RegState::Implicit)
655+
.addExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore")
656+
.addRegMask(TRI.getCallPreservedMask(
657+
MF,
658+
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
659+
}
660+
661+
// Emits the stack allocation of the buffer used by the full (agnostic) ZA
// save/restore routines, inserting before MBBI in MBB.
//
// Does nothing if AArch64FunctionInfo already reports an early-allocated SME
// save buffer. Otherwise the emitted sequence is:
//   1. BL __arm_sme_state_size            (result defined in X0)
//   2. BufferSize = COPY X0
//   3. SP = SUBXrx64 SP, BufferSize       (UXTX #0 extend)
//   4. BufferPtr = COPY SP
// and a variable-sized stack object (16-byte aligned) is registered with the
// frame info.
//
// PhysLiveRegs - live physical registers at the insertion point; handed to
//               ScopedPhysRegSave.
void MachineSMEABI::emitAllocateFullZASaveBuffer(
662+
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
663+
LiveRegs PhysLiveRegs) {
664+
MachineFunction &MF = *MBB.getParent();
665+
MachineFrameInfo &MFI = MF.getFrameInfo();
666+
auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
667+
const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
668+
MachineRegisterInfo &MRI = MF.getRegInfo();
669+
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
670+
671+
// Buffer already allocated in SelectionDAG.
672+
if (AFI->getEarlyAllocSMESaveBuffer())
673+
return;
674+
675+
DebugLoc DL = getDebugLoc(MBB, MBBI);
// BufferPtr is the register returned (and cached) by getAgnosticZABufferPtr();
// it is defined by the COPY from SP below.
676+
Register BufferPtr = getAgnosticZABufferPtr(MF);
677+
Register BufferSize = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
678+
// NOTE(review): ScopedPhysRegSave presumably preserves the registers named in
// PhysLiveRegs around the code emitted while it is in scope -- confirm against
// its definition (not visible here).
679+
ScopedPhysRegSave ScopedPhysRegSave(MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
680+
681+
// Calculate the SME state size.
682+
{
683+
const AArch64RegisterInfo *TRI = Subtarget.getRegisterInfo();
// The call implicitly defines X0; the size is then copied out of X0 into the
// BufferSize virtual register.
684+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::BL))
685+
.addExternalSymbol("__arm_sme_state_size")
686+
.addReg(AArch64::X0, RegState::ImplicitDefine)
687+
.addRegMask(TRI->getCallPreservedMask(
688+
MF, CallingConv::
689+
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
690+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), BufferSize)
691+
.addReg(AArch64::X0);
692+
}
693+
694+
// Allocate a buffer object of the size given by __arm_sme_state_size.
695+
{
// Subtract BufferSize from SP, then take the new SP as the buffer pointer.
696+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::SUBXrx64), AArch64::SP)
697+
.addReg(AArch64::SP)
698+
.addReg(BufferSize)
699+
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
700+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), BufferPtr)
701+
.addReg(AArch64::SP);
702+
703+
// We have just allocated a variable sized object, tell this to PEI.
704+
MFI.CreateVariableSizedObject(Align(16), nullptr);
705+
}
706+
}
707+
574708
void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
575709
MachineBasicBlock::iterator InsertPt,
576710
ZAState From, ZAState To,
577-
LiveRegs PhysLiveRegs) {
711+
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
578712

579713
// ZA not used.
580714
if (From == ZAState::ANY || To == ZAState::ANY)
@@ -601,10 +735,11 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
601735
}
602736

603737
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
604-
emitSetupLazySave(MBB, InsertPt);
738+
emitZASave(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
605739
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
606-
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
740+
emitZARestore(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
607741
else if (To == ZAState::OFF) {
742+
assert(!IsAgnosticZA && "Should not turn ZA off in agnostic ZA function");
608743
// If we're exiting from the CALLER_DORMANT state that means this new ZA
609744
// function did not touch ZA (so ZA was never turned on).
610745
if (From != ZAState::CALLER_DORMANT)
@@ -627,7 +762,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
627762

628763
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
629764
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
630-
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
765+
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
766+
!SMEFnAttrs.hasAgnosticZAInterface())
631767
return false;
632768

633769
assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -636,20 +772,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
636772
State = PassState{};
637773
Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
638774

775+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
776+
639777
collectNeededZAStates(MF, SMEFnAttrs);
640778
pickBundleZAStates(MF);
641-
insertStateChanges(MF);
779+
insertStateChanges(MF, /*IsAgnosticZA=*/IsAgnosticZA);
642780

643781
// Allocate save buffer (if needed).
644-
if (State.TPIDR2Block.has_value()) {
782+
if (State.HasFullZASaveRestore || State.TPIDR2Block.has_value()) {
645783
if (State.AfterSMEProloguePt) {
646784
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
647785
// entry block (due to the probing loop).
648-
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
649-
*State.AfterSMEProloguePt);
786+
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
787+
*State.AfterSMEProloguePt,
788+
State.PhysLiveRegsAfterSMEPrologue,
789+
/*IsAgnosticZA=*/IsAgnosticZA);
650790
} else {
651791
MachineBasicBlock &EntryBlock = MF.front();
652-
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
792+
emitAllocateZASaveBuffer(
793+
EntryBlock, EntryBlock.getFirstNonPHI(),
794+
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry,
795+
/*IsAgnosticZA=*/IsAgnosticZA);
653796
}
654797
}
655798

0 commit comments

Comments
 (0)