Skip to content

[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass #149064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: users/MacDue/windows-sme
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8292,7 +8292,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
if (Subtarget->hasCustomCallingConv())
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);

if (getTM().useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
if (getTM().useNewSMEABILowering()) {
if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
SDValue Size;
if (Attrs.hasZAState()) {
Expand Down Expand Up @@ -9113,9 +9113,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
// TODO: Handle agnostic ZA functions.
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
if (!UseNewSMEABILowering)
return std::nullopt;
if (IsAgnosticZAFunction) {
if (CallAttrs.requiresPreservingAllZAState())
return AArch64ISD::REQUIRES_ZA_SAVE;
return std::nullopt;
}
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
return std::nullopt;
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
Expand Down Expand Up @@ -9195,7 +9199,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
};

bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
bool RequiresSaveAllZA =
!UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
Expand Down
171 changes: 155 additions & 16 deletions llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This pass implements the SME ABI requirements for ZA state. This includes
// implementing the lazy ZA state save schemes around calls.
// implementing the lazy (and agnostic) ZA state save schemes around calls.
//
//===----------------------------------------------------------------------===//
//
Expand Down Expand Up @@ -200,7 +200,7 @@ struct MachineSMEABI : public MachineFunctionPass {

/// Inserts code to handle changes between ZA states within the function.
/// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
void insertStateChanges();
void insertStateChanges(bool IsAgnosticZA);

// Emission routines for private and shared ZA functions (using lazy saves).
void emitNewZAPrologue(MachineBasicBlock &MBB,
Expand All @@ -215,8 +215,41 @@ struct MachineSMEABI : public MachineFunctionPass {
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2);

// Emission routines for agnostic ZA functions.
void emitSetupFullZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs);
void emitFullZASaveRestore(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs, bool IsSave);
void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs);

void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
ZAState From, ZAState To, LiveRegs PhysLiveRegs);
ZAState From, ZAState To, LiveRegs PhysLiveRegs,
bool IsAgnosticZA);

// Helpers for switching between lazy/full ZA save/restore routines.
/// Emits the code that saves ZA at \p MBBI in \p MBB.
///
/// Functions with an agnostic ZA interface (\p IsAgnosticZA) use the full
/// ZA save routine; all other functions set up a lazy save instead.
/// \p PhysLiveRegs is only consumed by the full-save path.
void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
  if (!IsAgnosticZA) {
    emitSetupLazySave(MBB, MBBI);
    return;
  }
  emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
}
/// Emits the code that restores ZA at \p MBBI in \p MBB.
///
/// Functions with an agnostic ZA interface (\p IsAgnosticZA) use the full
/// ZA restore routine; all other functions restore from the lazy save.
void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                   LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
  if (!IsAgnosticZA) {
    emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
    return;
  }
  emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
}
/// Emits the allocation of the buffer used to save ZA at \p MBBI in \p MBB.
///
/// Functions with an agnostic ZA interface (\p IsAgnosticZA) allocate the
/// full ZA save buffer; all other functions allocate the lazy save buffer.
/// \p PhysLiveRegs is only consumed by the full-save path.
void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
                              MachineBasicBlock::iterator MBBI,
                              LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
  if (!IsAgnosticZA) {
    emitAllocateLazySaveBuffer(MBB, MBBI);
    return;
  }
  emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
}

/// Save live physical registers to virtual registers.
PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
Expand All @@ -228,6 +261,8 @@ struct MachineSMEABI : public MachineFunctionPass {
/// Get or create a TPIDR2 block in this function.
TPIDR2State getTPIDR2Block();

Register getAgnosticZABufferPtr();

private:
/// Contains the needed ZA state (and live registers) at an instruction.
struct InstInfo {
Expand All @@ -241,6 +276,7 @@ struct MachineSMEABI : public MachineFunctionPass {
struct BlockInfo {
ZAState FixedEntryState{ZAState::ANY};
SmallVector<InstInfo> Insts;
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};

Expand All @@ -250,6 +286,9 @@ struct MachineSMEABI : public MachineFunctionPass {
SmallVector<ZAState> BundleStates;
std::optional<TPIDR2State> TPIDR2Block;
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
Register AgnosticZABufferPtr = AArch64::NoRegister;
LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
bool HasFullZASaveRestore = false;
} State;

MachineFunction *MF = nullptr;
Expand All @@ -261,7 +300,8 @@ struct MachineSMEABI : public MachineFunctionPass {
};

void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
SMEFnAttrs.hasZAState()) &&
"Expected function to have ZA/ZT0 state!");

State.Blocks.resize(MF->getNumBlockIDs());
Expand Down Expand Up @@ -295,6 +335,7 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {

Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
for (MachineInstr &MI : reverse(MBB)) {
MachineBasicBlock::iterator MBBI(MI);
LiveUnits.stepBackward(MI);
Expand All @@ -305,7 +346,9 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
// block setup.
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
State.AfterSMEProloguePt = MBBI;
State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
}
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
auto [NeededState, InsertPt] = getZAStateBeforeInst(
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
assert((InsertPt == MBBI ||
Expand All @@ -314,6 +357,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
// TODO: Do something to avoid state changes where NZCV is live.
if (MBBI == FirstTerminatorInsertPt)
Block.PhysLiveRegsAtExit = PhysLiveRegs;
if (MBBI == FirstNonPhiInsertPt)
Block.PhysLiveRegsAtEntry = PhysLiveRegs;
if (NeededState != ZAState::ANY)
Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
}
Expand Down Expand Up @@ -380,7 +425,7 @@ void MachineSMEABI::assignBundleZAStates() {
}
}

void MachineSMEABI::insertStateChanges() {
void MachineSMEABI::insertStateChanges(bool IsAgnosticZA) {
for (MachineBasicBlock &MBB : *MF) {
const BlockInfo &Block = State.Blocks[MBB.getNumber()];
ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(),
Expand All @@ -393,7 +438,7 @@ void MachineSMEABI::insertStateChanges() {
for (auto &Inst : Block.Insts) {
if (CurrentState != Inst.NeededState)
emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
Inst.PhysLiveRegs);
Inst.PhysLiveRegs, IsAgnosticZA);
CurrentState = Inst.NeededState;
}

Expand All @@ -404,7 +449,7 @@ void MachineSMEABI::insertStateChanges() {
State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
if (CurrentState != OutState)
emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
Block.PhysLiveRegsAtExit);
Block.PhysLiveRegsAtExit, IsAgnosticZA);
}
}

Expand Down Expand Up @@ -618,10 +663,95 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
.addImm(1);
}

/// Returns the register holding the pointer to the agnostic ZA save buffer.
///
/// The result is cached in State.AgnosticZABufferPtr. On the first call the
/// buffer register recorded in AArch64FunctionInfo (an early allocation) is
/// reused if there is one; otherwise a fresh GPR64 virtual register is
/// created.
Register MachineSMEABI::getAgnosticZABufferPtr() {
  Register &CachedPtr = State.AgnosticZABufferPtr;
  if (CachedPtr != AArch64::NoRegister)
    return CachedPtr;

  auto *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
  auto EarlyBuffer = FuncInfo->getEarlyAllocSMESaveBuffer();
  if (EarlyBuffer != AArch64::NoRegister) {
    // Reuse the buffer that was already allocated earlier.
    CachedPtr = EarlyBuffer;
    return CachedPtr;
  }

  // No early buffer: hand out a new virtual register.
  CachedPtr = MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
  return CachedPtr;
}

/// Emits a call to the full SME state save (\p IsSave) or restore
/// (!\p IsSave) routine (__arm_sme_save/__arm_sme_restore) at \p MBBI.
///
/// The agnostic ZA buffer pointer is passed to the routine in X0. The
/// physical registers described by \p PhysLiveRegs are saved before the call
/// sequence and restored after it. Also records in State that a full ZA
/// save/restore has been emitted.
void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
                                          MachineBasicBlock::iterator MBBI,
                                          LiveRegs PhysLiveRegs, bool IsSave) {
  State.HasFullZASaveRestore = true;

  auto *Lowering = Subtarget->getTargetLowering();
  DebugLoc Loc = getDebugLoc(MBB, MBBI);
  Register ArgReg = AArch64::X0;
  const auto Routine =
      IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;

  PhysRegSave SavedRegs = createPhysRegSave(PhysLiveRegs, MBB, MBBI, Loc);

  // Materialize the buffer pointer in the argument register (X0).
  BuildMI(MBB, MBBI, Loc, TII->get(TargetOpcode::COPY), ArgReg)
      .addReg(getAgnosticZABufferPtr());

  // Emit the call to the selected save/restore routine; X0 is an implicit
  // use of the call.
  BuildMI(MBB, MBBI, Loc, TII->get(AArch64::BL))
      .addReg(ArgReg, RegState::Implicit)
      .addExternalSymbol(Lowering->getLibcallName(Routine))
      .addRegMask(TRI->getCallPreservedMask(
          *MF,
          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));

  restorePhyRegSave(SavedRegs, MBB, MBBI, Loc);
}

/// Emits the allocation of the full ZA save buffer at \p MBBI in \p MBB.
///
/// Does nothing if AArch64FunctionInfo already records an early-allocated
/// save buffer. Otherwise this calls the SME state-size routine, subtracts
/// the returned size from SP, and copies the new SP into the agnostic ZA
/// buffer pointer register. The physical registers described by
/// \p PhysLiveRegs are saved before and restored after the emitted sequence.
void MachineSMEABI::emitAllocateFullZASaveBuffer(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    LiveRegs PhysLiveRegs) {
  // Buffer already allocated in SelectionDAG.
  if (MF->getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer())
    return;

  DebugLoc Loc = getDebugLoc(MBB, MBBI);
  Register SaveBufferPtr = getAgnosticZABufferPtr();
  Register StateSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);

  PhysRegSave SavedRegs = createPhysRegSave(PhysLiveRegs, MBB, MBBI, Loc);

  // Step 1: query the SME state size. The result comes back in X0, which is
  // modelled as an implicit def of the call and then copied to a vreg.
  auto *Lowering = Subtarget->getTargetLowering();
  const AArch64RegisterInfo *RegInfo = Subtarget->getRegisterInfo();
  BuildMI(MBB, MBBI, Loc, TII->get(AArch64::BL))
      .addExternalSymbol(
          Lowering->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
      .addReg(AArch64::X0, RegState::ImplicitDefine)
      .addRegMask(RegInfo->getCallPreservedMask(
          *MF,
          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
  BuildMI(MBB, MBBI, Loc, TII->get(TargetOpcode::COPY), StateSize)
      .addReg(AArch64::X0);

  // Step 2: allocate a buffer of the size given by __arm_sme_state_size by
  // moving SP down, and remember the new SP as the buffer pointer.
  BuildMI(MBB, MBBI, Loc, TII->get(AArch64::SUBXrx64), AArch64::SP)
      .addReg(AArch64::SP)
      .addReg(StateSize)
      .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
  BuildMI(MBB, MBBI, Loc, TII->get(TargetOpcode::COPY), SaveBufferPtr)
      .addReg(AArch64::SP);

  // We have just allocated a variable sized object, tell this to PEI.
  MachineFrameInfo &FrameInfo = MF->getFrameInfo();
  FrameInfo.CreateVariableSizedObject(Align(16), nullptr);

  restorePhyRegSave(SavedRegs, MBB, MBBI, Loc);
}

void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
MachineBasicBlock::iterator InsertPt,
ZAState From, ZAState To,
LiveRegs PhysLiveRegs) {
LiveRegs PhysLiveRegs, bool IsAgnosticZA) {

// ZA not used.
if (From == ZAState::ANY || To == ZAState::ANY)
Expand Down Expand Up @@ -653,12 +783,13 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
}

if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
emitSetupLazySave(MBB, InsertPt);
emitZASave(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
emitZARestore(MBB, InsertPt, PhysLiveRegs, IsAgnosticZA);
else if (To == ZAState::OFF) {
assert(From != ZAState::CALLER_DORMANT &&
"CALLER_DORMANT to OFF should have already been handled");
assert(!IsAgnosticZA && "Should not turn ZA off in agnostic ZA function");
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
} else {
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
Expand All @@ -678,7 +809,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {

auto *AFI = MF.getInfo<AArch64FunctionInfo>();
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
!SMEFnAttrs.hasAgnosticZAInterface())
return false;

assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
Expand All @@ -692,20 +824,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
TRI = Subtarget->getRegisterInfo();
MRI = &MF.getRegInfo();

bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();

collectNeededZAStates(SMEFnAttrs);
assignBundleZAStates();
insertStateChanges();
insertStateChanges(/*IsAgnosticZA=*/IsAgnosticZA);

// Allocate save buffer (if needed).
if (State.TPIDR2Block) {
if (State.HasFullZASaveRestore || State.TPIDR2Block) {
if (State.AfterSMEProloguePt) {
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
// entry block (due to the probing loop).
emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
*State.AfterSMEProloguePt);
emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
*State.AfterSMEProloguePt,
State.PhysLiveRegsAfterSMEPrologue,
/*IsAgnosticZA=*/IsAgnosticZA);
} else {
MachineBasicBlock &EntryBlock = MF.front();
emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
emitAllocateZASaveBuffer(
EntryBlock, EntryBlock.getFirstNonPHI(),
State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry,
/*IsAgnosticZA=*/IsAgnosticZA);
}
}

Expand Down
Loading