Skip to content

[AArch64][SME] Port all SME routines to RuntimeLibcalls #152505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 13, 2025

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Aug 7, 2025

This updates everywhere we emit/check an SME routines to use RuntimeLibcalls to get the function name and calling convention.

Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so I tweaked the output to avoid the need for variables.

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-ir

Author: Benjamin Maxwell (MacDue)

Changes

This updates everywhere we emit/check an SME routines to use RuntimeLibcalls to get the function name and calling convention.

Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so tweaked the output to avoid the need for variables.


Patch is 38.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152505.diff

14 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+6)
  • (modified) llvm/include/llvm/IR/RuntimeLibcalls.td (+42-1)
  • (modified) llvm/include/llvm/IR/RuntimeLibcallsImpl.td (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64FrameLowering.cpp (+10-6)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+20-20)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+7-11)
  • (modified) llvm/lib/Target/AArch64/SMEABIPass.cpp (+23-8)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+29-10)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (+10-8)
  • (modified) llvm/test/TableGen/RuntimeLibcallEmitter-calling-conv.td (+16-48)
  • (modified) llvm/test/TableGen/RuntimeLibcallEmitter-conflict-warning.td (+7-7)
  • (modified) llvm/test/TableGen/RuntimeLibcallEmitter.td (+24-42)
  • (modified) llvm/unittests/Target/AArch64/SMEAttributesTest.cpp (+1-1)
  • (modified) llvm/utils/TableGen/Basic/RuntimeLibcallsEmitter.cpp (+15-18)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 01f8fb5ed061f..21062aa2675cc 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3558,6 +3558,12 @@ class LLVM_ABI TargetLoweringBase {
     return Libcalls.getLibcallImplName(Call);
   }
 
+  /// Check if this is valid libcall for the current module, otherwise
+  /// RTLIB::Unsupported.
+  RTLIB::LibcallImpl getSupportedLibcallImpl(StringRef FuncName) const {
+    return Libcalls.getSupportedLibcallImpl(FuncName);
+  }
+
   const char *getMemcpyName() const { return Libcalls.getMemcpyName(); }
 
   /// Get the comparison predicate that's to be used to test the result of the
diff --git a/llvm/include/llvm/IR/RuntimeLibcalls.td b/llvm/include/llvm/IR/RuntimeLibcalls.td
index df472d4b9cfee..e672edeeeda15 100644
--- a/llvm/include/llvm/IR/RuntimeLibcalls.td
+++ b/llvm/include/llvm/IR/RuntimeLibcalls.td
@@ -405,6 +405,17 @@ multiclass LibmLongDoubleLibCall<string libcall_basename = !toupper(NAME),
 def SC_MEMCPY : RuntimeLibcall;
 def SC_MEMMOVE : RuntimeLibcall;
 def SC_MEMSET : RuntimeLibcall;
+def SC_MEMCHR: RuntimeLibcall;
+
+// AArch64 SME ABI calls
+def SMEABI_SME_STATE : RuntimeLibcall;
+def SMEABI_TPIDR2_SAVE : RuntimeLibcall;
+def SMEABI_ZA_DISABLE : RuntimeLibcall;
+def SMEABI_TPIDR2_RESTORE : RuntimeLibcall;
+def SMEABI_GET_CURRENT_VG : RuntimeLibcall;
+def SMEABI_SME_STATE_SIZE : RuntimeLibcall;
+def SMEABI_SME_SAVE : RuntimeLibcall;
+def SMEABI_SME_RESTORE : RuntimeLibcall;
 
 // ARM EABI calls
 def AEABI_MEMCPY4 : RuntimeLibcall; // Align 4
@@ -1223,8 +1234,35 @@ defset list<RuntimeLibcallImpl> AArch64LibcallImpls = {
   def __arm_sc_memcpy : RuntimeLibcallImpl<SC_MEMCPY>;
   def __arm_sc_memmove : RuntimeLibcallImpl<SC_MEMMOVE>;
   def __arm_sc_memset : RuntimeLibcallImpl<SC_MEMSET>;
+  def __arm_sc_memchr : RuntimeLibcallImpl<SC_MEMCHR>;
 } // End AArch64LibcallImpls
 
+def __arm_sme_state : RuntimeLibcallImpl<SMEABI_SME_STATE>;
+def __arm_tpidr2_save : RuntimeLibcallImpl<SMEABI_TPIDR2_SAVE>;
+def __arm_za_disable : RuntimeLibcallImpl<SMEABI_ZA_DISABLE>;
+def __arm_tpidr2_restore : RuntimeLibcallImpl<SMEABI_TPIDR2_RESTORE>;
+def __arm_get_current_vg : RuntimeLibcallImpl<SMEABI_GET_CURRENT_VG>;
+def __arm_sme_state_size : RuntimeLibcallImpl<SMEABI_SME_STATE_SIZE>;
+def __arm_sme_save : RuntimeLibcallImpl<SMEABI_SME_SAVE>;
+def __arm_sme_restore : RuntimeLibcallImpl<SMEABI_SME_RESTORE>;
+
+def SMEABI_LibCalls_PreserveMost_From_X0 : LibcallsWithCC<(add
+  __arm_tpidr2_save,
+  __arm_za_disable,
+  __arm_tpidr2_restore),
+  SMEABI_PreserveMost_From_X0>;
+
+def SMEABI_LibCalls_PreserveMost_From_X1 : LibcallsWithCC<(add
+  __arm_get_current_vg,
+  __arm_sme_state_size,
+  __arm_sme_save,
+  __arm_sme_restore),
+  SMEABI_PreserveMost_From_X1>;
+
+def SMEABI_LibCalls_PreserveMost_From_X2 : LibcallsWithCC<(add
+  __arm_sme_state),
+  SMEABI_PreserveMost_From_X2>;
+
 def isAArch64_ExceptArm64EC
     : RuntimeLibcallPredicate<"(TT.isAArch64() && !TT.isWindowsArm64EC())">;
 def isWindowsArm64EC : RuntimeLibcallPredicate<"TT.isWindowsArm64EC()">;
@@ -1244,7 +1282,10 @@ def AArch64SystemLibrary : SystemRuntimeLibrary<
        LibmHasSinCosF32, LibmHasSinCosF64, LibmHasSinCosF128,
        DefaultLibmExp10,
        DefaultStackProtector,
-       SecurityCheckCookieIfWinMSVC)
+       SecurityCheckCookieIfWinMSVC,
+       SMEABI_LibCalls_PreserveMost_From_X0,
+       SMEABI_LibCalls_PreserveMost_From_X1,
+       SMEABI_LibCalls_PreserveMost_From_X2)
 >;
 
 // Prepend a # to every name
diff --git a/llvm/include/llvm/IR/RuntimeLibcallsImpl.td b/llvm/include/llvm/IR/RuntimeLibcallsImpl.td
index 601c291daf89d..b5752c1b69ad8 100644
--- a/llvm/include/llvm/IR/RuntimeLibcallsImpl.td
+++ b/llvm/include/llvm/IR/RuntimeLibcallsImpl.td
@@ -36,6 +36,9 @@ def ARM_AAPCS : LibcallCallingConv<[{CallingConv::ARM_AAPCS}]>;
 def ARM_AAPCS_VFP : LibcallCallingConv<[{CallingConv::ARM_AAPCS_VFP}]>;
 def X86_STDCALL : LibcallCallingConv<[{CallingConv::X86_StdCall}]>;
 def AVR_BUILTIN : LibcallCallingConv<[{CallingConv::AVR_BUILTIN}]>;
+def SMEABI_PreserveMost_From_X0 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0}]>;
+def SMEABI_PreserveMost_From_X1 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1}]>;
+def SMEABI_PreserveMost_From_X2 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2}]>;
 
 /// Abstract definition for functionality the compiler may need to
 /// emit a call to. Emits the RTLIB::Libcall enum - This enum defines
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 885f2a94f85f5..ba02c82b25aaf 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -1487,8 +1487,11 @@ bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
 
     if (Opc == AArch64::BL) {
       auto Op1 = MBBI->getOperand(0);
-      return Op1.isSymbol() &&
-             (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
+      auto &TLI =
+          *MBBI->getMF()->getSubtarget<AArch64Subtarget>().getTargetLowering();
+      char const *GetCurrentVG =
+          TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG);
+      return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG;
     }
   }
 
@@ -3468,6 +3471,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
     MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
     ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
   MachineFunction &MF = *MBB.getParent();
+  auto &TLI = *MF.getSubtarget<AArch64Subtarget>().getTargetLowering();
   const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
   AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
   bool NeedsWinCFI = needsWinCFI(MF);
@@ -3581,11 +3585,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
               .addReg(AArch64::X0, RegState::Implicit)
               .setMIFlag(MachineInstr::FrameSetup);
 
-        const uint32_t *RegMask = TRI->getCallPreservedMask(
-            MF,
-            CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
+        RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG;
+        const uint32_t *RegMask =
+            TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC));
         BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
-            .addExternalSymbol("__arm_get_current_vg")
+            .addExternalSymbol(TLI.getLibcallName(LC))
             .addRegMask(RegMask)
             .addReg(AArch64::X0, RegState::ImplicitDefine)
             .setMIFlag(MachineInstr::FrameSetup);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a40de86b4615b..e7f583a601e90 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3083,13 +3083,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
   AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
   if (FuncInfo->isSMESaveBufferUsed()) {
+    RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
     const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
-        .addExternalSymbol("__arm_sme_state_size")
+        .addExternalSymbol(getLibcallName(LC))
         .addReg(AArch64::X0, RegState::ImplicitDefine)
-        .addRegMask(TRI->getCallPreservedMask(
-            *MF, CallingConv::
-                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+        .addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
     BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
             MI.getOperand(0).getReg())
         .addReg(AArch64::X0);
@@ -5711,15 +5710,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
 SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
                                                   SDValue Chain, SDLoc DL,
                                                   EVT VT) const {
-  SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
+  RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
+  SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
                                          getPointerTy(DAG.getDataLayout()));
   Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
   Type *RetTy = StructType::get(Int64Ty, Int64Ty);
   TargetLowering::CallLoweringInfo CLI(DAG);
   ArgListTy Args;
   CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
-      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
-      RetTy, Callee, std::move(Args));
+      getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
   std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
   SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
   return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
@@ -8564,12 +8563,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
 }
 
 static SMECallAttrs
-getSMECallAttrs(const Function &Caller,
+getSMECallAttrs(const Function &Caller, const TargetLowering &TLI,
                 const TargetLowering::CallLoweringInfo &CLI) {
   if (CLI.CB)
-    return SMECallAttrs(*CLI.CB);
+    return SMECallAttrs(*CLI.CB, &TLI);
   if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
-    return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
+    return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
   return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
 }
 
@@ -8591,7 +8590,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
 
   // SME Streaming functions are not eligible for TCO as they may require
   // the streaming mode or ZA to be restored after returning from the call.
-  SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
+  SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
   if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
       CallAttrs.requiresPreservingAllZAState() ||
       CallAttrs.caller().hasStreamingBody())
@@ -8879,14 +8878,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
       DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
   Args.push_back(Entry);
 
-  SDValue Callee =
-      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
-                            TLI.getPointerTy(DAG.getDataLayout()));
+  RTLIB::Libcall LC =
+      IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
+  SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
   auto *RetTy = Type::getVoidTy(*DAG.getContext());
   TargetLowering::CallLoweringInfo CLI(DAG);
   CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
-      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
-      Callee, std::move(Args));
+      TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
   return TLI.LowerCallTo(CLI).second;
 }
 
@@ -9074,7 +9073,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   }
 
   // Determine whether we need any streaming mode changes.
-  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
+  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
 
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9659,11 +9658,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   if (RequiresLazySave) {
     // Conditionally restore the lazy save using a pseudo node.
+    RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
     TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     SDValue RegMask = DAG.getRegisterMask(
-        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+        TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
     SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
-        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+        getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
     SDValue TPIDR2_EL0 = DAG.getNode(
         ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
         DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
@@ -29004,7 +29004,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
 
   // Checks to allow the use of SME instructions
   if (auto *Base = dyn_cast<CallBase>(&Inst)) {
-    auto CallAttrs = SMECallAttrs(*Base);
+    auto CallAttrs = SMECallAttrs(*Base, this);
     if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
         CallAttrs.requiresPreservingZT0() ||
         CallAttrs.requiresPreservingAllZAState())
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 9f05add8bc1c1..6b698f76e0384 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -220,20 +220,16 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
 
-static bool isSMEABIRoutineCall(const CallInst &CI) {
+static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) {
   const auto *F = CI.getCalledFunction();
-  return F && StringSwitch<bool>(F->getName())
-                  .Case("__arm_sme_state", true)
-                  .Case("__arm_tpidr2_save", true)
-                  .Case("__arm_tpidr2_restore", true)
-                  .Case("__arm_za_disable", true)
-                  .Default(false);
+  return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
 }
 
 /// Returns true if the function has explicit operations that can only be
 /// lowered using incompatible instructions for the selected mode. This also
 /// returns true if the function F may use or modify ZA state.
-static bool hasPossibleIncompatibleOps(const Function *F) {
+static bool hasPossibleIncompatibleOps(const Function *F,
+                                       const TargetLowering &TLI) {
   for (const BasicBlock &BB : *F) {
     for (const Instruction &I : BB) {
       // Be conservative for now and assume that any call to inline asm or to
@@ -242,7 +238,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
       // all native LLVM instructions can be lowered to compatible instructions.
       if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
           (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
-           isSMEABIRoutineCall(cast<CallInst>(I))))
+           isSMEABIRoutineCall(cast<CallInst>(I), TLI)))
         return true;
     }
   }
@@ -290,7 +286,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
   if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
       CallAttrs.requiresPreservingZT0() ||
       CallAttrs.requiresPreservingAllZAState()) {
-    if (hasPossibleIncompatibleOps(Callee))
+    if (hasPossibleIncompatibleOps(Callee, *getTLI()))
       return false;
   }
 
@@ -357,7 +353,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
   // change only once and avoid inlining of G into F.
 
   SMEAttrs FAttrs(*F);
-  SMECallAttrs CallAttrs(Call);
+  SMECallAttrs CallAttrs(Call, getTLI());
 
   if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
     if (F == Call.getCaller()) // (1)
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 4af4d49306625..2008516885c35 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -15,11 +15,16 @@
 #include "AArch64.h"
 #include "Utils/AArch64SMEAttributes.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/RuntimeLibcalls.h"
+#include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
 using namespace llvm;
@@ -33,9 +38,13 @@ struct SMEABI : public FunctionPass {
 
   bool runOnFunction(Function &F) override;
 
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetPassConfig>();
+  }
+
 private:
   bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder,
-                               SMEAttrs FnAttrs);
+                               SMEAttrs FnAttrs, const TargetLowering &TLI);
 };
 } // end anonymous namespace
 
@@ -51,14 +60,16 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
 //===----------------------------------------------------------------------===//
 
 // Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
-void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
+void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, const TargetLowering &TLI,
+                    bool ZT0IsUndef = false) {
   auto &Ctx = M->getContext();
   auto *TPIDR2SaveTy =
       FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
   auto Attrs =
       AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
+  RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_SAVE;
   FunctionCallee Callee =
-      M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
+      M->getOrInsertFunction(TLI.getLibcallName(LC), TPIDR2SaveTy, Attrs);
   CallInst *Call = Builder.CreateCall(Callee);
 
   // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
@@ -67,8 +78,7 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
   if (ZT0IsUndef)
     Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
 
-  Call->setCallingConv(
-      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
+  Call->setCallingConv(TLI.getLibcallCallingConv(LC));
 
   // A save to TPIDR2 should be followed by clearing TPIDR2_EL0.
   Function *WriteIntr =
@@ -98,7 +108,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
 /// interface if it does not share ZA or ZT0.
 ///
 bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
-                                     IRBuilder<> &Builder, SMEAttrs FnAttrs) {
+                                     IRBuilder<> &Builder, SMEAttrs FnAttrs,
+                                     const TargetLowering &TLI) {
   LLVMContext &Context = F->getContext();
   BasicBlock *OrigBB = &F->getEntryBlock();
   Builder.SetInsertPoint(&OrigBB->front());
@@ -124,7 +135,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
 
     // Create a call __arm_tpidr2_save, which commits the lazy save.
     Builder.SetInsertPoint(&SaveBB->back());
-    emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
+    emitTPIDR2Save(M, Builder, TLI, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
 
     // Enable pstate.za at the start of the function.
     Builder.SetInsertPoint(&OrigBB->front());
@@ -172,10 +183,14 @@ bool SMEABI::runOnFunction(Function &F) {
   if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za"))
     return false;
 
+  const TargetMachine &TM =
+      getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+  const TargetLowering &TLI = *TM.getSubtargetImpl(F)->getTargetLowering();
+
   bool Changed = false;
   SMEAttrs FnAttrs(F);
   if (FnAttrs.isNewZA() || FnAttrs.isNewZT0())
-    Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);
+    Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs, TLI);
 
   return Changed;
 }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 271094f935e0e..bb788fcebe4ae 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/U...
[truncated]

@arsenm
Copy link
Contributor

arsenm commented Aug 7, 2025

Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so I tweaked the output to avoid the need for variables.

What was this issue? Do you have an example?

@MacDue
Copy link
Member Author

MacDue commented Aug 7, 2025

What was this issue? Do you have an example?

It'd make LibraryCalls_AlwaysAvailable for each SMEABI_LibCalls_PreserveMost_From_{X0,X1,X2} collection of libcalls. Just removing the variables was simpler than worrying about the names.

@MacDue
Copy link
Member Author

MacDue commented Aug 11, 2025

Kind ping :)

This updates everywhere we emit/check an SME routines to use
RuntimeLibcalls to get the function name and calling convention.

Note: RuntimeLibcallEmitter had some issues with emitting non-unique
variable names for sets of libcalls, so tweaked the output to avoid
the need for variables.
Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically you should be checking if the call is available before emitting it, but that isn't meaningfully testable yet until RuntimeLibcalls is turned into a proper analysis

@MacDue MacDue merged commit 271688b into llvm:main Aug 13, 2025
9 checks passed
@MacDue MacDue deleted the sme_runtimelibs branch August 13, 2025 07:49
@nikic
Copy link
Contributor

nikic commented Aug 13, 2025

@MacDue
Copy link
Member Author

MacDue commented Aug 13, 2025

This caused a huge compile-time regression for aarch64 builds: https://llvm-compile-time-tracker.com/compare.php?from=b9138bde3562de5c28a239dbd303caf2406678c6&to=271688b87abe7cf45aceaff8266270a25eb7b436&stat=instructions:u

Hm... I suspect it must be the TLI.getSupportedLibcallImpl() check, as that's the main difference. I'll try rewriting without that.

@arsenm
Copy link
Contributor

arsenm commented Aug 13, 2025

Hm... I suspect it must be the TLI.getSupportedLibcallImpl() check, as that's the main difference. I'll try rewriting without that.

Which #150192 makes much faster

@@ -110,11 +129,11 @@ bool SMECallAttrs::requiresSMChange() const {
return true;
}

SMECallAttrs::SMECallAttrs(const CallBase &CB)
SMECallAttrs::SMECallAttrs(const CallBase &CB, const TargetLowering *TLI)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can also use the AArch64 subclass of TargetLowering

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants