
[AArch64][SME] Implement the SME ABI (ZA state management) in Machine IR #149062

Merged · 16 commits · Aug 19, 2025
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64.h
@@ -60,6 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
FunctionPass *createAArch64CollectLOHPass();
FunctionPass *createSMEABIPass();
FunctionPass *createSMEPeepholeOptPass();
FunctionPass *createMachineSMEABIPass();
ModulePass *createSVEIntrinsicOptsPass();
InstructionSelector *
createAArch64InstructionSelector(const AArch64TargetMachine &,
@@ -111,6 +112,7 @@ void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
void initializeLDTLSCleanupPass(PassRegistry&);
void initializeSMEABIPass(PassRegistry &);
void initializeSMEPeepholeOptPass(PassRegistry &);
void initializeMachineSMEABIPass(PassRegistry &);
void initializeSVEIntrinsicOptsPass(PassRegistry &);
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
} // end namespace llvm
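
For context, this header hunk only declares the entry points for the new pass; a minimal sketch of how such a pass is typically wired into the AArch64 pipeline follows. The hook is real TargetPassConfig API, but the option name is an assumption for illustration — the actual wiring is not part of this hunk.

// Sketch (assumed wiring); only createMachineSMEABIPass() and
// initializeMachineSMEABIPass() above are introduced by this hunk.
void AArch64PassConfig::addPreRegAlloc() {
  if (EnableNewSMEABI) // hypothetical flag gating the new SME ABI lowering
    addPass(createMachineSMEABIPass()); // lowers ZA state-change markers
}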
57 changes: 43 additions & 14 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -92,8 +92,9 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI);
bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
MachineBasicBlock *expandRestoreZA(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
MachineBasicBlock *
expandCommitOrRestoreZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
};
@@ -990,44 +991,69 @@ bool AArch64ExpandPseudo::expandStoreSwiftAsyncContext(
return true;
}

MachineBasicBlock *
AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
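// Each bit of the mask selects one 64-bit ZA tile (ZA0.D-ZA7.D), so
// 0b11111111 zeroes all of ZA.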
static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;

MachineBasicBlock *AArch64ExpandPseudo::expandCommitOrRestoreZASave(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
MachineInstr &MI = *MBBI;
bool IsRestoreZA = MI.getOpcode() == AArch64::RestoreZAPseudo;
assert((MI.getOpcode() == AArch64::RestoreZAPseudo ||
MI.getOpcode() == AArch64::CommitZASavePseudo) &&
"Expected ZA commit or restore");
assert((std::next(MBBI) != MBB.end() ||
MI.getParent()->successors().begin() !=
MI.getParent()->successors().end()) &&
"Unexpected unreachable in block that restores ZA");

// Compare TPIDR2_EL0 value against 0.
DebugLoc DL = MI.getDebugLoc();
MachineInstrBuilder Cbz = BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX))
.add(MI.getOperand(0));
MachineInstrBuilder Branch =
BuildMI(MBB, MBBI, DL,
TII->get(IsRestoreZA ? AArch64::CBZX : AArch64::CBNZX))
.add(MI.getOperand(0));

// Split MBB and create two new blocks:
// - MBB now contains all instructions before RestoreZAPseudo.
// - SMBB contains the RestoreZAPseudo instruction only.
// - EndBB contains all instructions after RestoreZAPseudo.
// - SMBB contains the [Commit|RestoreZA]Pseudo instruction only.
// - EndBB contains all instructions after [Commit|RestoreZA]Pseudo.
MachineInstr &PrevMI = *std::prev(MBBI);
MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
? *SMBB->successors().begin()
: SMBB->splitAt(MI, /*UpdateLiveIns*/ true);

// Add the SMBB label to the TB[N]Z instruction & create a branch to EndBB.
Cbz.addMBB(SMBB);
// Add the SMBB label to the CB[N]Z instruction & create a branch to EndBB.
Branch.addMBB(SMBB);
BuildMI(&MBB, DL, TII->get(AArch64::B))
.addMBB(EndBB);
MBB.addSuccessor(EndBB);

// Replace the pseudo with a call (BL).
MachineInstrBuilder MIB =
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL));
MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
// Copy operands (mainly the regmask) from the pseudo.
for (unsigned I = 2; I < MI.getNumOperands(); ++I)
MIB.add(MI.getOperand(I));
BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);

if (IsRestoreZA) {
// Mark the TPIDR2 block pointer (X0) as an implicit use.
MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
} else /*CommitZA*/ {
auto *TRI = MBB.getParent()->getSubtarget().getRegisterInfo();
// Clear TPIDR2_EL0.
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::MSR))
.addImm(AArch64SysReg::TPIDR2_EL0)
.addReg(AArch64::XZR);
bool ZeroZA = MI.getOperand(1).getImm() != 0;
if (ZeroZA) {
assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!");
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::ZERO_M))
.addImm(ZERO_ALL_ZA_MASK)
.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
}
}

BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
MI.eraseFromParent();
return EndBB;
}
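
The expansion above replaces the pseudo with a guarded runtime call. The resulting control flow looks roughly like this (a sketch; x8 stands in for the register holding the TPIDR2_EL0 value from operand 0):

// MBB:                     ; instructions before the pseudo
//   cbz  x8, SMBB          ; RestoreZA: restore only if TPIDR2_EL0 == 0
//   ; (CommitZASave uses cbnz: commit only if TPIDR2_EL0 != 0)
//   b    EndBB
// SMBB:
//   bl   <routine>         ; e.g. __arm_tpidr2_restore or __arm_tpidr2_save
//   msr  TPIDR2_EL0, xzr   ; commit only: clear the lazy-save pointer
//   zero {za}              ; commit only, when the ZeroZA operand is set
//   b    EndBB
// EndBB:                   ; instructions after the pseudo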
@@ -1646,8 +1672,9 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
return expandCALL_BTI(MBB, MBBI);
case AArch64::StoreSwiftAsyncContext:
return expandStoreSwiftAsyncContext(MBB, MBBI);
case AArch64::CommitZASavePseudo:
case AArch64::RestoreZAPseudo: {
auto *NewMBB = expandRestoreZA(MBB, MBBI);
auto *NewMBB = expandCommitOrRestoreZASave(MBB, MBBI);
if (NewMBB != &MBB)
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
@@ -1658,6 +1685,8 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
}
case AArch64::InOutZAUsePseudo:
case AArch64::RequiresZASavePseudo:
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
case AArch64::COALESCER_BARRIER_FPR64:
149 changes: 91 additions & 58 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17,6 +17,7 @@
#include "AArch64PerfectShuffle.h"
#include "AArch64RegisterInfo.h"
#include "AArch64Subtarget.h"
#include "AArch64TargetMachine.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "Utils/AArch64BaseInfo.h"
#include "Utils/AArch64SMEAttributes.h"
@@ -1998,6 +1999,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::f16, Promote);
}

const AArch64TargetMachine &AArch64TargetLowering::getTM() const {
return static_cast<const AArch64TargetMachine &>(getTargetMachine());
}

void AArch64TargetLowering::addTypeForNEON(MVT VT) {
assert(VT.isVector() && "VT should be a vector type");

@@ -8284,53 +8289,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
if (Subtarget->hasCustomCallingConv())
Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);

// Create a 16 Byte TPIDR2 object. The dynamic buffer
// will be expanded and stored in the static object later using a pseudonode.
if (Attrs.hasZAState()) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
DAG.getConstant(1, DL, MVT::i32));

SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
} else {
SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
DAG.getVTList(MVT::i64, MVT::Other),
{Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
} else if (Attrs.hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
DAG.getVTList(MVT::i64, MVT::Other), Chain);
Chain = BufferSize.getValue(1);

SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer =
DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
} else {
// Allocate space dynamically.
Buffer = DAG.getNode(
ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
{Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
if (!getTM().useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
// Old SME ABI lowering (deprecated):
// Create a 16 Byte TPIDR2 object. The dynamic buffer
// will be expanded and stored in the static object later using a
// pseudonode.
if (Attrs.hasZAState()) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
DAG.getConstant(1, DL, MVT::i32));
SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
} else {
SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
DAG.getVTList(MVT::i64, MVT::Other),
{Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
} else if (Attrs.hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
DAG.getVTList(MVT::i64, MVT::Other), Chain);
Chain = BufferSize.getValue(1);
SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other),
{Chain, BufferSize});
} else {
// Allocate space dynamically.
Buffer = DAG.getNode(
ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
{Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}
// Copy the value to a virtual register, and save that in FuncInfo.
Register BufferPtr =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
FuncInfo->setSMESaveBufferAddr(BufferPtr);
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}

// Copy the value to a virtual register, and save that in FuncInfo.
Register BufferPtr =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
FuncInfo->setSMESaveBufferAddr(BufferPtr);
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}
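
For reference, the 16-byte stack object reserved above is the TPIDR2 block of the SME lazy-save scheme. A sketch of its layout per the AAPCS64 SME ABI (field names are illustrative, not LLVM types):

struct alignas(16) TPIDR2Block {
  void *ZASaveBuffer;        // SVL.B x SVL.B bytes, i.e. (RDSVL #1) squared
  uint16_t NumZASaveSlices;  // number of ZA slices to save/restore
  uint8_t Reserved[6];       // must be zero
};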

if (CallConv == CallingConv::PreserveNone) {
@@ -8347,6 +8353,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
}
}

if (getTM().useNewSMEABILowering()) {
// Clear new ZT0 state. TODO: Move this to the SME ABI pass.
if (Attrs.isNewZT0())
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
DAG.getTargetConstant(0, DL, MVT::i32));
}
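
The aarch64_sme_zero_zt intrinsic emitted here is expected to lower to a single SME2 instruction on entry (sketch):

// zero { zt0 }   ; clear ZT0 for a function with "new" ZT0 state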

return Chain;
}

@@ -8918,7 +8933,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setSMESaveBufferUsed();

TargetLowering::ArgListTy Args;
Args.emplace_back(
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
@@ -9059,14 +9073,28 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
CallConv = CallingConv::AArch64_SVE_VectorCall;
}

// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
// TODO: Handle agnostic ZA functions.
if (!UseNewSMEABILowering || IsAgnosticZAFunction)
return std::nullopt;
if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
return std::nullopt;
return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
: AArch64ISD::INOUT_ZA_USE;
}();
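
Summarizing the marker selection above (a restatement of the lambda, not new logic):

// old lowering, or caller with agnostic-ZA interface  -> no marker
// caller has ZA/ZT0 state, call requires a lazy save  -> REQUIRES_ZA_SAVE
// caller has ZA/ZT0 state, ZA merely live across call -> INOUT_ZA_USE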

if (IsTailCall) {
// Check if it's really possible to do a tail call.
IsTailCall = isEligibleForTailCallOptimization(CLI);

// A sibling call is one where we're under the usual C ABI and not planning
// to change that but can still do a tail call:
if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
CallConv != CallingConv::SwiftTail)
if (!ZAMarkerNode && !TailCallOpt && IsTailCall &&
CallConv != CallingConv::Tail && CallConv != CallingConv::SwiftTail)
IsSibCall = true;

if (IsTailCall)
@@ -9118,9 +9146,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
}

// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);

auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9134,7 +9159,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return R;
};

bool RequiresLazySave = CallAttrs.requiresLazySave();
bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9209,10 +9234,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));

// Adjust the stack pointer for the new arguments...
// Adjust the stack pointer for the new arguments... and mark ZA uses.
// These operations are automatically eliminated by the prolog/epilog pass
if (!IsSibCall)
assert((!IsSibCall || !ZAMarkerNode) && "ZA markers require CALLSEQ_START");
if (!IsSibCall) {
Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
if (ZAMarkerNode) {
// Note: We need the CALLSEQ_START to glue the ZAMarkerNode to; simply using
// a chain can result in incorrect scheduling. The markers refer to the
// position just before the CALLSEQ_START (though they occur after it, since
// CALLSEQ_START lacks in-glue).
Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
{Chain, Chain.getValue(1)});
}
}
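
In DAG terms, the glue described in the note pins the marker to the call sequence; the result looks roughly like (a sketch):

// t0: ch      = <incoming chain>
// t1: ch,glue = CALLSEQ_START t0, ...
// t2: ch      = REQUIRES_ZA_SAVE t1, t1:1 ; consuming t1's glue keeps the
//                                         ; marker attached to CALLSEQ_START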

SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
getPointerTy(DAG.getDataLayout()));
@@ -9683,7 +9718,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
}

if (CallAttrs.requiresEnablingZAAfterCall())
if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9705,7 +9740,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue TPIDR2_EL0 = DAG.getNode(
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));

// Copy the address of the TPIDR2 block into X0 before 'calling' the
// RESTORE_ZA pseudo.
SDValue Glue;
@@ -9717,7 +9751,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
{Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
RestoreRoutine, RegMask, Result.getValue(1)});

// Finally reset the TPIDR2_EL0 register to 0.
Result = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
8 changes: 8 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -23,6 +23,8 @@

namespace llvm {

class AArch64TargetMachine;

namespace AArch64 {
/// Possible values of current rounding mode, which is specified in bits
/// 23:22 of FPCR.
@@ -64,6 +66,8 @@ class AArch64TargetLowering : public TargetLowering {
explicit AArch64TargetLowering(const TargetMachine &TM,
const AArch64Subtarget &STI);

const AArch64TargetMachine &getTM() const;

/// Control the following reassociation of operands: (op (op x, c1), y) -> (op
/// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
@@ -173,6 +177,10 @@ MachineBasicBlock *EmitZTInstr(MachineInstr &MI, MachineBasicBlock *BB,
MachineBasicBlock *EmitZTInstr(MachineInstr &MI, MachineBasicBlock *BB,
unsigned Opcode, bool Op0IsDef) const;
MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const;

// Note: The following group of functions are only used as part of the old SME
// ABI lowering. They will be removed once -aarch64-new-sme-abi=true is the
// default.
MachineBasicBlock *EmitInitTPIDR2Object(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
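
Per the note in the header above, the old lowering remains the default until -aarch64-new-sme-abi=true; until then the new path is opt-in, e.g. (a sketch — the exact invocation may differ):

llc -mattr=+sme -aarch64-new-sme-abi=true input.ll -o -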