From faa2fe9e11ce5b04ac8711ffe4c381bd66691943 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 7 Aug 2025 13:40:49 +0000 Subject: [PATCH 1/9] [AArch64][SME] Port all SME routines to RuntimeLibcalls This updates everywhere we emit/check an SME routines to use RuntimeLibcalls to get the function name and calling convention. Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so tweaked the output to avoid the need for variables. --- llvm/include/llvm/CodeGen/TargetLowering.h | 6 +++ llvm/include/llvm/IR/RuntimeLibcalls.td | 43 ++++++++++++++++++- llvm/include/llvm/IR/RuntimeLibcallsImpl.td | 3 ++ .../Target/AArch64/AArch64FrameLowering.cpp | 16 ++++--- .../Target/AArch64/AArch64ISelLowering.cpp | 40 ++++++++--------- .../AArch64/AArch64TargetTransformInfo.cpp | 18 +++----- llvm/lib/Target/AArch64/SMEABIPass.cpp | 31 +++++++++---- .../AArch64/Utils/AArch64SMEAttributes.cpp | 39 ++++++++++++----- .../AArch64/Utils/AArch64SMEAttributes.h | 18 ++++---- .../Target/AArch64/SMEAttributesTest.cpp | 2 +- 10 files changed, 151 insertions(+), 65 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index ed7495694cc70..a7352de2ad673 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3566,6 +3566,12 @@ class LLVM_ABI TargetLoweringBase { return Libcalls.getMemcpyName().data(); } + /// Check if this is valid libcall for the current module, otherwise + /// RTLIB::Unsupported. + RTLIB::LibcallImpl getSupportedLibcallImpl(StringRef FuncName) const { + return Libcalls.getSupportedLibcallImpl(FuncName); + } + /// Get the comparison predicate that's to be used to test the result of the /// comparison libcall against zero. This should only be used with /// floating-point compare libcalls. diff --git a/llvm/include/llvm/IR/RuntimeLibcalls.td b/llvm/include/llvm/IR/RuntimeLibcalls.td index 9072a0aa1531f..9626004cbed42 100644 --- a/llvm/include/llvm/IR/RuntimeLibcalls.td +++ b/llvm/include/llvm/IR/RuntimeLibcalls.td @@ -406,6 +406,17 @@ multiclass LibmLongDoubleLibCall AArch64LibcallImpls = { def __arm_sc_memcpy : RuntimeLibcallImpl; def __arm_sc_memmove : RuntimeLibcallImpl; def __arm_sc_memset : RuntimeLibcallImpl; + def __arm_sc_memchr : RuntimeLibcallImpl; } // End AArch64LibcallImpls +def __arm_sme_state : RuntimeLibcallImpl; +def __arm_tpidr2_save : RuntimeLibcallImpl; +def __arm_za_disable : RuntimeLibcallImpl; +def __arm_tpidr2_restore : RuntimeLibcallImpl; +def __arm_get_current_vg : RuntimeLibcallImpl; +def __arm_sme_state_size : RuntimeLibcallImpl; +def __arm_sme_save : RuntimeLibcallImpl; +def __arm_sme_restore : RuntimeLibcallImpl; + +def SMEABI_LibCalls_PreserveMost_From_X0 : LibcallsWithCC<(add + __arm_tpidr2_save, + __arm_za_disable, + __arm_tpidr2_restore), + SMEABI_PreserveMost_From_X0>; + +def SMEABI_LibCalls_PreserveMost_From_X1 : LibcallsWithCC<(add + __arm_get_current_vg, + __arm_sme_state_size, + __arm_sme_save, + __arm_sme_restore), + SMEABI_PreserveMost_From_X1>; + +def SMEABI_LibCalls_PreserveMost_From_X2 : LibcallsWithCC<(add + __arm_sme_state), + SMEABI_PreserveMost_From_X2>; + def isAArch64_ExceptArm64EC : RuntimeLibcallPredicate<"(TT.isAArch64() && !TT.isWindowsArm64EC())">; def isWindowsArm64EC : RuntimeLibcallPredicate<"TT.isWindowsArm64EC()">; @@ -1244,7 +1282,10 @@ def AArch64SystemLibrary : SystemRuntimeLibrary< LibmHasSinCosF32, LibmHasSinCosF64, LibmHasSinCosF128, DefaultLibmExp10, DefaultStackProtector, - SecurityCheckCookieIfWinMSVC) + SecurityCheckCookieIfWinMSVC, + SMEABI_LibCalls_PreserveMost_From_X0, + SMEABI_LibCalls_PreserveMost_From_X1, + SMEABI_LibCalls_PreserveMost_From_X2) >; // Prepend a # to every name diff --git a/llvm/include/llvm/IR/RuntimeLibcallsImpl.td b/llvm/include/llvm/IR/RuntimeLibcallsImpl.td index 601c291daf89d..b5752c1b69ad8 100644 --- a/llvm/include/llvm/IR/RuntimeLibcallsImpl.td +++ b/llvm/include/llvm/IR/RuntimeLibcallsImpl.td @@ -36,6 +36,9 @@ def ARM_AAPCS : LibcallCallingConv<[{CallingConv::ARM_AAPCS}]>; def ARM_AAPCS_VFP : LibcallCallingConv<[{CallingConv::ARM_AAPCS_VFP}]>; def X86_STDCALL : LibcallCallingConv<[{CallingConv::X86_StdCall}]>; def AVR_BUILTIN : LibcallCallingConv<[{CallingConv::AVR_BUILTIN}]>; +def SMEABI_PreserveMost_From_X0 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0}]>; +def SMEABI_PreserveMost_From_X1 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1}]>; +def SMEABI_PreserveMost_From_X2 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2}]>; /// Abstract definition for functionality the compiler may need to /// emit a call to. Emits the RTLIB::Libcall enum - This enum defines diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index 885f2a94f85f5..ba02c82b25aaf 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -1487,8 +1487,11 @@ bool isVGInstruction(MachineBasicBlock::iterator MBBI) { if (Opc == AArch64::BL) { auto Op1 = MBBI->getOperand(0); - return Op1.isSymbol() && - (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg"); + auto &TLI = + *MBBI->getMF()->getSubtarget().getTargetLowering(); + char const *GetCurrentVG = + TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG); + return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG; } } @@ -3468,6 +3471,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, ArrayRef CSI, const TargetRegisterInfo *TRI) const { MachineFunction &MF = *MBB.getParent(); + auto &TLI = *MF.getSubtarget().getTargetLowering(); const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); AArch64FunctionInfo *AFI = MF.getInfo(); bool NeedsWinCFI = needsWinCFI(MF); @@ -3581,11 +3585,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( .addReg(AArch64::X0, RegState::Implicit) .setMIFlag(MachineInstr::FrameSetup); - const uint32_t *RegMask = TRI->getCallPreservedMask( - MF, - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1); + RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG; + const uint32_t *RegMask = + TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC)); BuildMI(MBB, MI, DL, TII.get(AArch64::BL)) - .addExternalSymbol("__arm_get_current_vg") + .addExternalSymbol(TLI.getLibcallName(LC)) .addRegMask(RegMask) .addReg(AArch64::X0, RegState::ImplicitDefine) .setMIFlag(MachineInstr::FrameSetup); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 2072e48914ae6..556b04c2be323 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3083,13 +3083,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, AArch64FunctionInfo *FuncInfo = MF->getInfo(); const TargetInstrInfo *TII = Subtarget->getInstrInfo(); if (FuncInfo->isSMESaveBufferUsed()) { + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) - .addExternalSymbol("__arm_sme_state_size") + .addExternalSymbol(getLibcallName(LC)) .addReg(AArch64::X0, RegState::ImplicitDefine) - .addRegMask(TRI->getCallPreservedMask( - *MF, CallingConv:: - AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1)); + .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC))); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), MI.getOperand(0).getReg()) .addReg(AArch64::X0); @@ -5739,15 +5738,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) { SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL, EVT VT) const { - SDValue Callee = DAG.getExternalSymbol("__arm_sme_state", + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE; + SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC), getPointerTy(DAG.getDataLayout())); Type *Int64Ty = Type::getInt64Ty(*DAG.getContext()); Type *RetTy = StructType::get(Int64Ty, Int64Ty); TargetLowering::CallLoweringInfo CLI(DAG); ArgListTy Args; CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2, - RetTy, Callee, std::move(Args)); + getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); std::pair CallResult = LowerCallTo(CLI); SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64); return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0), @@ -8600,12 +8599,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, } static SMECallAttrs -getSMECallAttrs(const Function &Caller, +getSMECallAttrs(const Function &Caller, const TargetLowering &TLI, const TargetLowering::CallLoweringInfo &CLI) { if (CLI.CB) - return SMECallAttrs(*CLI.CB); + return SMECallAttrs(*CLI.CB, &TLI); if (auto *ES = dyn_cast(CLI.Callee)) - return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol())); + return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI)); return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal)); } @@ -8627,7 +8626,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // SME Streaming functions are not eligible for TCO as they may require // the streaming mode or ZA to be restored after returning from the call. - SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI); + SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState() || CallAttrs.caller().hasStreamingBody()) @@ -8921,14 +8920,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64); Args.push_back(Entry); - SDValue Callee = - DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore", - TLI.getPointerTy(DAG.getDataLayout())); + RTLIB::Libcall LC = + IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE; + SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC), + TLI.getPointerTy(DAG.getDataLayout())); auto *RetTy = Type::getVoidTy(*DAG.getContext()); TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy, - Callee, std::move(Args)); + TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)); return TLI.LowerCallTo(CLI).second; } @@ -9116,7 +9115,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } // Determine whether we need any streaming mode changes. - SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI); + SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI); auto DescribeCallsite = [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & { @@ -9693,11 +9692,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (RequiresLazySave) { // Conditionally restore the lazy save using a pseudo node. + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE; TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); SDValue RegMask = DAG.getRegisterMask( - TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC))); SDValue RestoreRoutine = DAG.getTargetExternalSymbol( - "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout())); + getLibcallName(LC), getPointerTy(DAG.getDataLayout())); SDValue TPIDR2_EL0 = DAG.getNode( ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result, DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32)); @@ -29036,7 +29036,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { // Checks to allow the use of SME instructions if (auto *Base = dyn_cast(&Inst)) { - auto CallAttrs = SMECallAttrs(*Base); + auto CallAttrs = SMECallAttrs(*Base, this); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 3042251cf754d..adf4b33b8021a 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -220,20 +220,16 @@ static cl::opt EnableFixedwidthAutovecInStreamingMode( static cl::opt EnableScalableAutovecInStreamingMode( "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden); -static bool isSMEABIRoutineCall(const CallInst &CI) { +static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) { const auto *F = CI.getCalledFunction(); - return F && StringSwitch(F->getName()) - .Case("__arm_sme_state", true) - .Case("__arm_tpidr2_save", true) - .Case("__arm_tpidr2_restore", true) - .Case("__arm_za_disable", true) - .Default(false); + return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine(); } /// Returns true if the function has explicit operations that can only be /// lowered using incompatible instructions for the selected mode. This also /// returns true if the function F may use or modify ZA state. -static bool hasPossibleIncompatibleOps(const Function *F) { +static bool hasPossibleIncompatibleOps(const Function *F, + const TargetLowering &TLI) { for (const BasicBlock &BB : *F) { for (const Instruction &I : BB) { // Be conservative for now and assume that any call to inline asm or to @@ -242,7 +238,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) { // all native LLVM instructions can be lowered to compatible instructions. if (isa(I) && !I.isDebugOrPseudoInst() && (cast(I).isInlineAsm() || isa(I) || - isSMEABIRoutineCall(cast(I)))) + isSMEABIRoutineCall(cast(I), TLI))) return true; } } @@ -290,7 +286,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) { - if (hasPossibleIncompatibleOps(Callee)) + if (hasPossibleIncompatibleOps(Callee, *getTLI())) return false; } @@ -357,7 +353,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, // change only once and avoid inlining of G into F. SMEAttrs FAttrs(*F); - SMECallAttrs CallAttrs(Call); + SMECallAttrs CallAttrs(Call, getTLI()); if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) { if (F == Call.getCaller()) // (1) diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp index 4af4d49306625..2008516885c35 100644 --- a/llvm/lib/Target/AArch64/SMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp @@ -15,11 +15,16 @@ #include "AArch64.h" #include "Utils/AArch64SMEAttributes.h" #include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/RuntimeLibcalls.h" +#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; @@ -33,9 +38,13 @@ struct SMEABI : public FunctionPass { bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + } + private: bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder, - SMEAttrs FnAttrs); + SMEAttrs FnAttrs, const TargetLowering &TLI); }; } // end anonymous namespace @@ -51,14 +60,16 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); } //===----------------------------------------------------------------------===// // Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0. -void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { +void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, const TargetLowering &TLI, + bool ZT0IsUndef = false) { auto &Ctx = M->getContext(); auto *TPIDR2SaveTy = FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false); auto Attrs = AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible"); + RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_SAVE; FunctionCallee Callee = - M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs); + M->getOrInsertFunction(TLI.getLibcallName(LC), TPIDR2SaveTy, Attrs); CallInst *Call = Builder.CreateCall(Callee); // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark @@ -67,8 +78,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { if (ZT0IsUndef) Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef")); - Call->setCallingConv( - CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0); + Call->setCallingConv(TLI.getLibcallCallingConv(LC)); // A save to TPIDR2 should be followed by clearing TPIDR2_EL0. Function *WriteIntr = @@ -98,7 +108,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) { /// interface if it does not share ZA or ZT0. /// bool SMEABI::updateNewStateFunctions(Module *M, Function *F, - IRBuilder<> &Builder, SMEAttrs FnAttrs) { + IRBuilder<> &Builder, SMEAttrs FnAttrs, + const TargetLowering &TLI) { LLVMContext &Context = F->getContext(); BasicBlock *OrigBB = &F->getEntryBlock(); Builder.SetInsertPoint(&OrigBB->front()); @@ -124,7 +135,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F, // Create a call __arm_tpidr2_save, which commits the lazy save. Builder.SetInsertPoint(&SaveBB->back()); - emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0()); + emitTPIDR2Save(M, Builder, TLI, /*ZT0IsUndef=*/FnAttrs.isNewZT0()); // Enable pstate.za at the start of the function. Builder.SetInsertPoint(&OrigBB->front()); @@ -172,10 +183,14 @@ bool SMEABI::runOnFunction(Function &F) { if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za")) return false; + const TargetMachine &TM = + getAnalysis().getTM(); + const TargetLowering &TLI = *TM.getSubtargetImpl(F)->getTargetLowering(); + bool Changed = false; SMEAttrs FnAttrs(F); if (FnAttrs.isNewZA() || FnAttrs.isNewZT0()) - Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs); + Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs, TLI); return Changed; } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 271094f935e0e..bb788fcebe4ae 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "AArch64SMEAttributes.h" +#include "llvm/CodeGen/TargetLowering.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/RuntimeLibcalls.h" #include using namespace llvm; @@ -77,19 +79,36 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { Bitmask |= encodeZT0State(StateValue::New); } -void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) { +void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, + const TargetLowering &TLI) { + RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); + if (Impl == RTLIB::Unsupported) + return; + RTLIB::Libcall LC = RTLIB::RuntimeLibcallsInfo::getLibcallFromImpl(Impl); unsigned KnownAttrs = SMEAttrs::Normal; - if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state") + switch (LC) { + case RTLIB::SMEABI_SME_STATE: + case RTLIB::SMEABI_TPIDR2_SAVE: + case RTLIB::SMEABI_GET_CURRENT_VG: + case RTLIB::SMEABI_SME_STATE_SIZE: + case RTLIB::SMEABI_SME_SAVE: + case RTLIB::SMEABI_SME_RESTORE: KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine); - if (FuncName == "__arm_tpidr2_restore") + break; + case RTLIB::SMEABI_ZA_DISABLE: + case RTLIB::SMEABI_TPIDR2_RESTORE: KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) | SMEAttrs::SME_ABI_Routine; - if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" || - FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr") + break; + case RTLIB::SC_MEMCPY: + case RTLIB::SC_MEMMOVE: + case RTLIB::SC_MEMSET: + case RTLIB::SC_MEMCHR: KnownAttrs |= SMEAttrs::SM_Compatible; - if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" || - FuncName == "__arm_sme_state_size") - KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; + break; + default: + break; + } set(KnownAttrs); } @@ -110,11 +129,11 @@ bool SMECallAttrs::requiresSMChange() const { return true; } -SMECallAttrs::SMECallAttrs(const CallBase &CB) +SMECallAttrs::SMECallAttrs(const CallBase &CB, const TargetLowering *TLI) : CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal), Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) { if (auto *CalledFunction = CB.getCalledFunction()) - CalledFn = SMEAttrs(*CalledFunction, SMEAttrs::InferAttrsFromName::Yes); + CalledFn = SMEAttrs(*CalledFunction, TLI); // FIXME: We probably should not allow SME attributes on direct calls but // clang duplicates streaming mode attributes at each callsite. diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index f1be0ecbee7ed..06376c74025f8 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -13,6 +13,8 @@ namespace llvm { +class TargetLowering; + class Function; class CallBase; class AttributeList; @@ -48,17 +50,17 @@ class SMEAttrs { CallSiteFlags_Mask = ZT0_Undef }; - enum class InferAttrsFromName { No, Yes }; - SMEAttrs() = default; SMEAttrs(unsigned Mask) { set(Mask); } - SMEAttrs(const Function &F, InferAttrsFromName Infer = InferAttrsFromName::No) + SMEAttrs(const Function &F, const TargetLowering *TLI = nullptr) : SMEAttrs(F.getAttributes()) { - if (Infer == InferAttrsFromName::Yes) - addKnownFunctionAttrs(F.getName()); + if (TLI) + addKnownFunctionAttrs(F.getName(), *TLI); } SMEAttrs(const AttributeList &L); - SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); }; + SMEAttrs(StringRef FuncName, const TargetLowering &TLI) { + addKnownFunctionAttrs(FuncName, TLI); + }; void set(unsigned M, bool Enable = true); @@ -146,7 +148,7 @@ class SMEAttrs { } private: - void addKnownFunctionAttrs(StringRef FuncName); + void addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI); }; /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has @@ -163,7 +165,7 @@ class SMECallAttrs { SMEAttrs Callsite = SMEAttrs::Normal) : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} - SMECallAttrs(const CallBase &CB); + SMECallAttrs(const CallBase &CB, const TargetLowering *TLI); SMEAttrs &caller() { return CallerFn; } SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; } diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp index f13252f3a4c28..e90f733d79fca 100644 --- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp +++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp @@ -78,7 +78,7 @@ TEST(SMEAttributes, Constructors) { "ret void\n}"); CallBase &Call = cast((CallModule->getFunction("foo")->begin()->front())); - ASSERT_TRUE(SMECallAttrs(Call).callsite().hasUndefZT0()); + ASSERT_TRUE(SMECallAttrs(Call, nullptr).callsite().hasUndefZT0()); // Invalid combinations. EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible), From 97b3cf5d96e028d774637b54e9f8799f02ee66d5 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 12 Aug 2025 08:47:01 +0000 Subject: [PATCH 2/9] Rebase: Use RuntimeLibcalls for EmitEntryPStateSM --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 7 +++---- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 556b04c2be323..224bbe7e38a19 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3108,13 +3108,12 @@ AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI, const TargetInstrInfo *TII = Subtarget->getInstrInfo(); Register ResultReg = MI.getOperand(0).getReg(); if (FuncInfo->isPStateSMRegUsed()) { + RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) - .addExternalSymbol("__arm_sme_state") + .addExternalSymbol(getLibcallName(LC)) .addReg(AArch64::X0, RegState::ImplicitDefine) - .addRegMask(TRI->getCallPreservedMask( - *MF, CallingConv:: - AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)); + .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC))); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg) .addReg(AArch64::X0); } else { diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index bb788fcebe4ae..934f68b29922a 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -93,7 +93,7 @@ void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, case RTLIB::SMEABI_SME_STATE_SIZE: case RTLIB::SMEABI_SME_SAVE: case RTLIB::SMEABI_SME_RESTORE: - KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine); + KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; break; case RTLIB::SMEABI_ZA_DISABLE: case RTLIB::SMEABI_TPIDR2_RESTORE: From ddba54e60b7ebf498cbed9e5a9436f6ca6bd46e6 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 13 Aug 2025 14:13:49 +0000 Subject: [PATCH 3/9] Avoid TLI.getSupportedLibcallImpl Rewrite to check a (much smaller) list of libcalls --- .../Target/AArch64/AArch64FrameLowering.cpp | 35 +++++------ .../AArch64/Utils/AArch64SMEAttributes.cpp | 58 ++++++++++--------- .../AArch64/Utils/AArch64SMEAttributes.h | 2 +- 3 files changed, 51 insertions(+), 44 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index ba02c82b25aaf..ae6b486e758c1 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -1475,27 +1475,25 @@ static bool requiresSaveVG(const MachineFunction &MF) { return true; } -bool isVGInstruction(MachineBasicBlock::iterator MBBI) { +static bool matchLibcall(const TargetLowering &TLI, const MachineOperand &MO, + RTLIB::Libcall LC) { + return MO.isSymbol() && TLI.getLibcallName(LC) == MO.getSymbolName(); +} + +bool isVGInstruction(MachineBasicBlock::iterator MBBI, + const TargetLowering &TLI) { unsigned Opc = MBBI->getOpcode(); if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI || Opc == AArch64::UBFMXri) return true; - if (requiresGetVGCall(*MBBI->getMF())) { - if (Opc == AArch64::ORRXrr) - return true; + if (!requiresGetVGCall(*MBBI->getMF())) + return false; - if (Opc == AArch64::BL) { - auto Op1 = MBBI->getOperand(0); - auto &TLI = - *MBBI->getMF()->getSubtarget().getTargetLowering(); - char const *GetCurrentVG = - TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG); - return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG; - } - } + if (Opc == AArch64::BL) + return matchLibcall(TLI, MBBI->getOperand(0), RTLIB::SMEABI_GET_CURRENT_VG); - return false; + return Opc == AArch64::ORRXrr; } // Convert callee-save register save/restore instruction to do stack pointer @@ -1514,9 +1512,11 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec( // functions, we need to do this for both the streaming and non-streaming // vector length. Move past these instructions if necessary. MachineFunction &MF = *MBB.getParent(); - if (requiresSaveVG(MF)) - while (isVGInstruction(MBBI)) + if (requiresSaveVG(MF)) { + auto &TLI = *MF.getSubtarget().getTargetLowering(); + while (isVGInstruction(MBBI, TLI)) ++MBBI; + } switch (MBBI->getOpcode()) { default: @@ -2100,11 +2100,12 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF, // Move past the saves of the callee-saved registers, fixing up the offsets // and pre-inc if we decided to combine the callee-save and local stack // pointer bump above. + auto &TLI = *MF.getSubtarget().getTargetLowering(); while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) && !IsSVECalleeSave(MBBI)) { if (CombineSPBump && // Only fix-up frame-setup load/store instructions. - (!requiresSaveVG(MF) || !isVGInstruction(MBBI))) + (!requiresSaveVG(MF) || !isVGInstruction(MBBI, TLI))) fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(), NeedsWinCFI, &HasWinCFI); ++MBBI; diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 934f68b29922a..86ef0e0058263 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -81,34 +81,40 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI) { - RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); - if (Impl == RTLIB::Unsupported) - return; - RTLIB::Libcall LC = RTLIB::RuntimeLibcallsInfo::getLibcallFromImpl(Impl); + struct SMERoutineAttr { + RTLIB::Libcall LC{RTLIB::UNKNOWN_LIBCALL}; + unsigned Attrs{SMEAttrs::Normal}; + }; + + static constexpr unsigned SMCompatiableABIRoutine = + SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; + static constexpr unsigned SMCompatiableABIRoutineInZA = + SMCompatiableABIRoutine | encodeZAState(StateValue::In); + + // Table of SME routine -> Known attributes. + static constexpr SMERoutineAttr SMERoutineAttrs[]{ + {RTLIB::SMEABI_SME_STATE, SMCompatiableABIRoutine}, + {RTLIB::SMEABI_TPIDR2_SAVE, SMCompatiableABIRoutine}, + {RTLIB::SMEABI_GET_CURRENT_VG, SMCompatiableABIRoutine}, + {RTLIB::SMEABI_SME_STATE_SIZE, SMCompatiableABIRoutine}, + {RTLIB::SMEABI_SME_SAVE, SMCompatiableABIRoutine}, + {RTLIB::SMEABI_SME_RESTORE, SMCompatiableABIRoutine}, + {RTLIB::SMEABI_ZA_DISABLE, SMCompatiableABIRoutineInZA}, + {RTLIB::SMEABI_TPIDR2_RESTORE, SMCompatiableABIRoutineInZA}, + {RTLIB::SC_MEMCPY, SMEAttrs::SM_Compatible}, + {RTLIB::SC_MEMMOVE, SMEAttrs::SM_Compatible}, + {RTLIB::SC_MEMSET, SMEAttrs::SM_Compatible}, + {RTLIB::SC_MEMCHR, SMEAttrs::SM_Compatible}, + }; + unsigned KnownAttrs = SMEAttrs::Normal; - switch (LC) { - case RTLIB::SMEABI_SME_STATE: - case RTLIB::SMEABI_TPIDR2_SAVE: - case RTLIB::SMEABI_GET_CURRENT_VG: - case RTLIB::SMEABI_SME_STATE_SIZE: - case RTLIB::SMEABI_SME_SAVE: - case RTLIB::SMEABI_SME_RESTORE: - KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; - break; - case RTLIB::SMEABI_ZA_DISABLE: - case RTLIB::SMEABI_TPIDR2_RESTORE: - KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) | - SMEAttrs::SME_ABI_Routine; - break; - case RTLIB::SC_MEMCPY: - case RTLIB::SC_MEMMOVE: - case RTLIB::SC_MEMSET: - case RTLIB::SC_MEMCHR: - KnownAttrs |= SMEAttrs::SM_Compatible; - break; - default: - break; + for (auto [LC, Attrs] : SMERoutineAttrs) { + if (TLI.getLibcallName(LC) == FuncName) { + KnownAttrs = Attrs; + break; + } } + set(KnownAttrs); } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index 06376c74025f8..bfbb5deec50bd 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -84,7 +84,7 @@ class SMEAttrs { static StateValue decodeZAState(unsigned Bitmask) { return static_cast((Bitmask & ZA_Mask) >> ZA_Shift); } - static unsigned encodeZAState(StateValue S) { + static constexpr unsigned encodeZAState(StateValue S) { return static_cast(S) << ZA_Shift; } From fc90c3dbde66415d6d2047f40ff01989c0bb0911 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 13 Aug 2025 14:28:41 +0000 Subject: [PATCH 4/9] Use stringref for check --- llvm/lib/Target/AArch64/AArch64FrameLowering.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index ae6b486e758c1..fddde668b7f1a 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -1477,7 +1477,8 @@ static bool requiresSaveVG(const MachineFunction &MF) { static bool matchLibcall(const TargetLowering &TLI, const MachineOperand &MO, RTLIB::Libcall LC) { - return MO.isSymbol() && TLI.getLibcallName(LC) == MO.getSymbolName(); + return MO.isSymbol() && + StringRef(TLI.getLibcallName(LC)) == MO.getSymbolName(); } bool isVGInstruction(MachineBasicBlock::iterator MBBI, From 80b7661434b5809e9702a0622d254f8f6aa3f63c Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 13 Aug 2025 18:16:04 +0000 Subject: [PATCH 5/9] Check prefix --- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 86ef0e0058263..7ccb7f2a4806b 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -81,6 +81,10 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI) { + // Skip functions that do not appear to be builtins (starting with "__"). + if (!FuncName.starts_with('_')) + return; + struct SMERoutineAttr { RTLIB::Libcall LC{RTLIB::UNKNOWN_LIBCALL}; unsigned Attrs{SMEAttrs::Normal}; From 00d1d91f75daf2ca46d2005643709ff373082c45 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 13 Aug 2025 18:33:31 +0000 Subject: [PATCH 6/9] Revert "Check prefix" This reverts commit 1da81a040218b5ef20e0f7d12c6d0088158abb51. --- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 7ccb7f2a4806b..86ef0e0058263 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -81,10 +81,6 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI) { - // Skip functions that do not appear to be builtins (starting with "__"). - if (!FuncName.starts_with('_')) - return; - struct SMERoutineAttr { RTLIB::Libcall LC{RTLIB::UNKNOWN_LIBCALL}; unsigned Attrs{SMEAttrs::Normal}; From 9f9497c40c43c7f425cf01ed3c8c4c2cc2fa0f41 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 13 Aug 2025 20:24:24 +0000 Subject: [PATCH 7/9] Check prefix --- .../lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 11 +++++------ llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h | 11 ++++++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 86ef0e0058263..90ab5e0fbc456 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -14,12 +14,7 @@ using namespace llvm; -void SMEAttrs::set(unsigned M, bool Enable) { - if (Enable) - Bitmask |= M; - else - Bitmask &= ~M; - +void SMEAttrs::validate() const { // Streaming Mode Attrs assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) && "SM_Enabled and SM_Compatible are mutually exclusive"); @@ -81,6 +76,10 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI) { + // If the function name does not start with a _ or #_ is not a builtin. + if (!FuncName.starts_with('_') && !FuncName.starts_with("#_")) + return; + struct SMERoutineAttr { RTLIB::Libcall LC{RTLIB::UNKNOWN_LIBCALL}; unsigned Attrs{SMEAttrs::Normal}; diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index bfbb5deec50bd..fee3a29cdece8 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -62,7 +62,15 @@ class SMEAttrs { addKnownFunctionAttrs(FuncName, TLI); }; - void set(unsigned M, bool Enable = true); + void set(unsigned M, bool Enable = true) { + if (Enable) + Bitmask |= M; + else + Bitmask &= ~M; +#ifndef NDEBUG + validate(); +#endif + } // Interfaces to query PSTATE.SM bool hasStreamingBody() const { return Bitmask & SM_Body; } @@ -149,6 +157,7 @@ class SMEAttrs { private: void addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI); + void validate() const; }; /// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has From 8235c92c0b0b8c62227d4290a663f6dbbb2aa084 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 15 Aug 2025 07:55:05 +0000 Subject: [PATCH 8/9] Re-try getSupportedLibcallImpl --- .../AArch64/Utils/AArch64SMEAttributes.cpp | 60 ++++++++----------- .../AArch64/Utils/AArch64SMEAttributes.h | 2 +- 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 90ab5e0fbc456..fd2243eb7f09b 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -76,44 +76,34 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI) { - // If the function name does not start with a _ or #_ is not a builtin. - if (!FuncName.starts_with('_') && !FuncName.starts_with("#_")) + RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); + if (Impl == RTLIB::Unsupported) return; - - struct SMERoutineAttr { - RTLIB::Libcall LC{RTLIB::UNKNOWN_LIBCALL}; - unsigned Attrs{SMEAttrs::Normal}; - }; - - static constexpr unsigned SMCompatiableABIRoutine = - SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; - static constexpr unsigned SMCompatiableABIRoutineInZA = - SMCompatiableABIRoutine | encodeZAState(StateValue::In); - - // Table of SME routine -> Known attributes. - static constexpr SMERoutineAttr SMERoutineAttrs[]{ - {RTLIB::SMEABI_SME_STATE, SMCompatiableABIRoutine}, - {RTLIB::SMEABI_TPIDR2_SAVE, SMCompatiableABIRoutine}, - {RTLIB::SMEABI_GET_CURRENT_VG, SMCompatiableABIRoutine}, - {RTLIB::SMEABI_SME_STATE_SIZE, SMCompatiableABIRoutine}, - {RTLIB::SMEABI_SME_SAVE, SMCompatiableABIRoutine}, - {RTLIB::SMEABI_SME_RESTORE, SMCompatiableABIRoutine}, - {RTLIB::SMEABI_ZA_DISABLE, SMCompatiableABIRoutineInZA}, - {RTLIB::SMEABI_TPIDR2_RESTORE, SMCompatiableABIRoutineInZA}, - {RTLIB::SC_MEMCPY, SMEAttrs::SM_Compatible}, - {RTLIB::SC_MEMMOVE, SMEAttrs::SM_Compatible}, - {RTLIB::SC_MEMSET, SMEAttrs::SM_Compatible}, - {RTLIB::SC_MEMCHR, SMEAttrs::SM_Compatible}, - }; - unsigned KnownAttrs = SMEAttrs::Normal; - for (auto [LC, Attrs] : SMERoutineAttrs) { - if (TLI.getLibcallName(LC) == FuncName) { - KnownAttrs = Attrs; - break; - } + RTLIB::Libcall LC = RTLIB::RuntimeLibcallsInfo::getLibcallFromImpl(Impl); + switch (LC) { + case RTLIB::SMEABI_SME_STATE: + case RTLIB::SMEABI_TPIDR2_SAVE: + case RTLIB::SMEABI_GET_CURRENT_VG: + case RTLIB::SMEABI_SME_STATE_SIZE: + case RTLIB::SMEABI_SME_SAVE: + case RTLIB::SMEABI_SME_RESTORE: + KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; + break; + case RTLIB::SMEABI_ZA_DISABLE: + case RTLIB::SMEABI_TPIDR2_RESTORE: + KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) | + SMEAttrs::SME_ABI_Routine; + break; + case RTLIB::SC_MEMCPY: + case RTLIB::SC_MEMMOVE: + case RTLIB::SC_MEMSET: + case RTLIB::SC_MEMCHR: + KnownAttrs |= SMEAttrs::SM_Compatible; + break; + default: + break; } - set(KnownAttrs); } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index fee3a29cdece8..851436e3b7cce 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -92,7 +92,7 @@ class SMEAttrs { static StateValue decodeZAState(unsigned Bitmask) { return static_cast((Bitmask & ZA_Mask) >> ZA_Shift); } - static constexpr unsigned encodeZAState(StateValue S) { + static unsigned encodeZAState(StateValue S) { return static_cast(S) << ZA_Shift; } From a00a6e4da340f3b683be519f5554f9769d0a1201 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 15 Aug 2025 08:02:24 +0000 Subject: [PATCH 9/9] Fixup --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +- .../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 5 +++-- .../lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp | 6 +++--- llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h | 11 ++++++----- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 224bbe7e38a19..9b1d1c6e03870 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -8598,7 +8598,7 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, } static SMECallAttrs -getSMECallAttrs(const Function &Caller, const TargetLowering &TLI, +getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI, const TargetLowering::CallLoweringInfo &CLI) { if (CLI.CB) return SMECallAttrs(*CLI.CB, &TLI); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index adf4b33b8021a..403c77f7aca35 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -220,7 +220,8 @@ static cl::opt EnableFixedwidthAutovecInStreamingMode( static cl::opt EnableScalableAutovecInStreamingMode( "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden); -static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) { +static bool isSMEABIRoutineCall(const CallInst &CI, + const AArch64TargetLowering &TLI) { const auto *F = CI.getCalledFunction(); return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine(); } @@ -229,7 +230,7 @@ static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) { /// lowered using incompatible instructions for the selected mode. This also /// returns true if the function F may use or modify ZA state. static bool hasPossibleIncompatibleOps(const Function *F, - const TargetLowering &TLI) { + const AArch64TargetLowering &TLI) { for (const BasicBlock &BB : *F) { for (const Instruction &I : BB) { // Be conservative for now and assume that any call to inline asm or to diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index fd2243eb7f09b..dd6fa167c6f4d 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "AArch64SMEAttributes.h" -#include "llvm/CodeGen/TargetLowering.h" +#include "AArch64ISelLowering.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/RuntimeLibcalls.h" #include @@ -75,7 +75,7 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { } void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, - const TargetLowering &TLI) { + const AArch64TargetLowering &TLI) { RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); if (Impl == RTLIB::Unsupported) return; @@ -124,7 +124,7 @@ bool SMECallAttrs::requiresSMChange() const { return true; } -SMECallAttrs::SMECallAttrs(const CallBase &CB, const TargetLowering *TLI) +SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI) : CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal), Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) { if (auto *CalledFunction = CB.getCalledFunction()) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index 851436e3b7cce..48f9da02d3182 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -13,7 +13,7 @@ namespace llvm { -class TargetLowering; +class AArch64TargetLowering; class Function; class CallBase; @@ -52,13 +52,13 @@ class SMEAttrs { SMEAttrs() = default; SMEAttrs(unsigned Mask) { set(Mask); } - SMEAttrs(const Function &F, const TargetLowering *TLI = nullptr) + SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr) : SMEAttrs(F.getAttributes()) { if (TLI) addKnownFunctionAttrs(F.getName(), *TLI); } SMEAttrs(const AttributeList &L); - SMEAttrs(StringRef FuncName, const TargetLowering &TLI) { + SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) { addKnownFunctionAttrs(FuncName, TLI); }; @@ -156,7 +156,8 @@ class SMEAttrs { } private: - void addKnownFunctionAttrs(StringRef FuncName, const TargetLowering &TLI); + void addKnownFunctionAttrs(StringRef FuncName, + const AArch64TargetLowering &TLI); void validate() const; }; @@ -174,7 +175,7 @@ class SMECallAttrs { SMEAttrs Callsite = SMEAttrs::Normal) : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} - SMECallAttrs(const CallBase &CB, const TargetLowering *TLI); + SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI); SMEAttrs &caller() { return CallerFn; } SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }