[AArch64][SME] Use entry pstate.sm for conditional streaming-mode changes #152169


Merged
merged 3 commits into from
Aug 12, 2025

Conversation

MacDue
Member

@MacDue MacDue commented Aug 5, 2025

We only do conditional streaming mode changes in two cases:

  • Around calls in streaming-compatible functions that don't have a streaming body
  • At the entry/exit of streaming-compatible functions with a streaming body

In both cases, the condition depends on the entry pstate.sm value. Given this, we don't need to emit calls to __arm_sme_state at every mode change.

This patch handles this by placing an "AArch64ISD::ENTRY_PSTATE_SM" node in the entry block and copying its result to a register. The register is then used whenever we need to emit a conditional streaming mode change. The "ENTRY_PSTATE_SM" node expands to a call to "__arm_sme_state" only if (after SelectionDAG) the function is determined to have streaming-mode changes.

This has two main advantages:

  1. It allows back-to-back conditional smstart/stop pairs to be folded
  2. It has the correct behaviour for EH landing pads
    • These are entered with pstate.sm = 0, and should switch mode based on the entry pstate.sm
    • Note: This is not fully implemented yet
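The folding win in advantage 1 can be sketched with a toy model (plain C++, not LLVM code; `Kind`, `MI`, and `foldPairs` are invented for illustration). Once every conditional smstart/smstop reads the same register holding the entry pstate.sm, adjacent opposite-sense pairs cancel:

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy model of conditional streaming-mode changes. Each
// instruction records whether it is a conditional smstart or smstop and
// which register holds the entry pstate.sm value it tests.
enum class Kind { CondSMStart, CondSMStop };
struct MI {
  Kind kind;
  int pstateReg; // register holding the entry pstate.sm
};

// Remove back-to-back start/stop (or stop/start) pairs that test the
// same pstate.sm register: neither changes the observable streaming
// mode across the pair, so both can be dropped. This is a simplified
// stand-in for what SMEPeepholeOpt can now do, since after this patch
// all conditional changes read the one entry-pstate.sm register.
std::vector<MI> foldPairs(std::vector<MI> in) {
  std::vector<MI> out;
  for (const MI &mi : in) {
    if (!out.empty() && out.back().kind != mi.kind &&
        out.back().pstateReg == mi.pstateReg)
      out.pop_back(); // the pair cancels out
    else
      out.push_back(mi);
  }
  return out;
}
```

Before this patch each conditional change read a fresh virtual register (a fresh `__arm_sme_state` result), so the `pstateReg` fields never matched and no pair could fold.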

@llvmbot
Member

llvmbot commented Aug 5, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: Benjamin Maxwell (MacDue)

Changes

Patch is 50.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152169.diff

19 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp (+3-1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+61-30)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+4-2)
  • (modified) llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h (+6)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+9)
  • (modified) llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp (+3-9)
  • (modified) llvm/test/CodeGen/AArch64/sme-agnostic-za.ll (+13-15)
  • (modified) llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll (+2-2)
  • (modified) llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll (+2-2)
  • (modified) llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll (+1-1)
  • (modified) llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll (+2-2)
  • (modified) llvm/test/CodeGen/AArch64/sme-peephole-opts.ll (+3-18)
  • (modified) llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll (+4-5)
  • (modified) llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll (+14-13)
  • (modified) llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll (+6-5)
  • (modified) llvm/test/CodeGen/AArch64/spill-reload-remarks.ll (+1-2)
  • (modified) llvm/test/CodeGen/AArch64/stack-hazard.ll (+60-63)
  • (modified) llvm/test/CodeGen/AArch64/streaming-compatible-memory-ops.ll (+3-4)
  • (modified) llvm/test/CodeGen/AArch64/sve-stack-frame-layout.ll (+2-2)
diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
index 8c8daef6dccd4..763b3868a99ca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
@@ -1178,7 +1178,9 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
   if (Node->getValueType(Node->getNumValues()-1) == MVT::Glue) {
     for (SDNode *F = Node->getGluedUser(); F; F = F->getGluedUser()) {
       if (F->getOpcode() == ISD::CopyFromReg) {
-        UsedRegs.push_back(cast<RegisterSDNode>(F->getOperand(1))->getReg());
+        Register Reg = cast<RegisterSDNode>(F->getOperand(1))->getReg();
+        if (Reg.isPhysical())
+          UsedRegs.push_back(Reg);
         continue;
       } else if (F->getOpcode() == ISD::CopyToReg) {
         // Skip CopyToReg nodes that are internal to the glue chain.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2b6ea86ee1af5..dc679e5baa196 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3101,6 +3101,32 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
   return BB;
 }
 
+MachineBasicBlock *
+AArch64TargetLowering::EmitCallerIsStreaming(MachineInstr &MI,
+                                             MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->IsPStateSMRegUsed()) {
+    const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+        .addExternalSymbol("__arm_sme_state")
+        .addReg(AArch64::X0, RegState::ImplicitDefine)
+        .addRegMask(TRI->getCallPreservedMask(
+            *MF, CallingConv::
+                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            MI.getOperand(0).getReg())
+        .addReg(AArch64::X0);
+  } else {
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            MI.getOperand(0).getReg())
+        .addReg(AArch64::XZR);
+  }
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 // Helper function to find the instruction that defined a virtual register.
 // If unable to find such instruction, returns nullptr.
 static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
@@ -3216,6 +3242,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitAllocateSMESaveBuffer(MI, BB);
   case AArch64::GetSMESaveSize:
     return EmitGetSMESaveSize(MI, BB);
+  case AArch64::CallerIsStreaming:
+    return EmitCallerIsStreaming(MI, BB);
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -8132,19 +8160,26 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
   }
   assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
 
+  if (Attrs.hasStreamingCompatibleInterface()) {
+    SDValue CallerIsStreaming =
+        DAG.getNode(AArch64ISD::CALLER_IS_STREAMING, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), {Chain});
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register CallerIsStreamingReg =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    Chain = DAG.getCopyToReg(CallerIsStreaming.getValue(1), DL,
+                             CallerIsStreamingReg, CallerIsStreaming);
+    FuncInfo->setPStateSMReg(CallerIsStreamingReg);
+  }
+
   // Insert the SMSTART if this is a locally streaming function and
   // make sure it is Glued to the last CopyFromReg value.
   if (IsLocallyStreaming) {
-    SDValue PStateSM;
-    if (Attrs.hasStreamingCompatibleInterface()) {
-      PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
-      Register Reg = MF.getRegInfo().createVirtualRegister(
-          getRegClassFor(PStateSM.getValueType().getSimpleVT()));
-      FuncInfo->setPStateSMReg(Reg);
-      Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
+    if (Attrs.hasStreamingCompatibleInterface())
       Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
-                                  AArch64SME::IfCallerIsNonStreaming, PStateSM);
-    } else
+                                  AArch64SME::IfCallerIsNonStreaming);
+    else
       Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
                                   AArch64SME::Always);
 
@@ -8834,8 +8869,7 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
 SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
                                                    bool Enable, SDValue Chain,
                                                    SDValue InGlue,
-                                                   unsigned Condition,
-                                                   SDValue PStateSM) const {
+                                                   unsigned Condition) const {
   MachineFunction &MF = DAG.getMachineFunction();
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   FuncInfo->setHasStreamingModeChanges(true);
@@ -8847,9 +8881,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   SmallVector<SDValue> Ops = {Chain, MSROp};
   unsigned Opcode;
   if (Condition != AArch64SME::Always) {
+    FuncInfo->setPStateSMRegUsed(true);
+    Register PStateReg = FuncInfo->getPStateSMReg();
+    assert(PStateReg.isValid() && "PStateSM Register is invalid");
+    SDValue PStateSM =
+        DAG.getCopyFromReg(Chain, DL, PStateReg, MVT::i64, InGlue);
+    // Use chain and glue from the CopyFromReg.
+    Ops[0] = PStateSM.getValue(1);
+    InGlue = PStateSM.getValue(2);
     SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
     Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
-    assert(PStateSM && "PStateSM should be defined");
     Ops.push_back(ConditionOp);
     Ops.push_back(PStateSM);
   } else {
@@ -9124,15 +9165,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                     /*IsSave=*/true);
   }
 
-  SDValue PStateSM;
   bool RequiresSMChange = CallAttrs.requiresSMChange();
   if (RequiresSMChange) {
-    if (CallAttrs.caller().hasStreamingInterfaceOrBody())
-      PStateSM = DAG.getConstant(1, DL, MVT::i64);
-    else if (CallAttrs.caller().hasNonStreamingInterface())
-      PStateSM = DAG.getConstant(0, DL, MVT::i64);
-    else
-      PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
     OptimizationRemarkEmitter ORE(&MF.getFunction());
     ORE.emit([&]() {
       auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -9447,9 +9481,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       InGlue = Chain.getValue(1);
     }
 
-    SDValue NewChain = changeStreamingMode(
-        DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
-        getSMToggleCondition(CallAttrs), PStateSM);
+    SDValue NewChain =
+        changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
+                            Chain, InGlue, getSMToggleCondition(CallAttrs));
     Chain = NewChain.getValue(0);
     InGlue = NewChain.getValue(1);
   }
@@ -9633,10 +9667,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     InGlue = Result.getValue(Result->getNumValues() - 1);
 
   if (RequiresSMChange) {
-    assert(PStateSM && "Expected a PStateSM to be set");
     Result = changeStreamingMode(
         DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
-        getSMToggleCondition(CallAttrs), PStateSM);
+        getSMToggleCondition(CallAttrs));
 
     if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
       InGlue = Result.getValue(1);
@@ -9802,14 +9835,11 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   // Emit SMSTOP before returning from a locally streaming function
   SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
   if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
-    if (FuncAttrs.hasStreamingCompatibleInterface()) {
-      Register Reg = FuncInfo->getPStateSMReg();
-      assert(Reg.isValid() && "PStateSM Register is invalid");
-      SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
+    if (FuncAttrs.hasStreamingCompatibleInterface())
       Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
                                   /*Glue*/ SDValue(),
-                                  AArch64SME::IfCallerIsNonStreaming, PStateSM);
-    } else
+                                  AArch64SME::IfCallerIsNonStreaming);
+    else
       Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
                                   /*Glue*/ SDValue(), AArch64SME::Always);
     Glue = Chain.getValue(1);
@@ -28166,6 +28196,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
     case Intrinsic::aarch64_sme_in_streaming_mode: {
       SDLoc DL(N);
       SDValue Chain = DAG.getEntryNode();
+
       SDValue RuntimePStateSM =
           getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
       Results.push_back(
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 88876570ac811..cc23f7e0bdfcd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -181,6 +181,8 @@ class AArch64TargetLowering : public TargetLowering {
                                                MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
                                         MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitCallerIsStreaming(MachineInstr &MI,
+                                           MachineBasicBlock *BB) const;
 
   /// Replace (0, vreg) discriminator components with the operands of blend
   /// or with (immediate, NoRegister) when possible.
@@ -523,8 +525,8 @@ class AArch64TargetLowering : public TargetLowering {
   /// node. \p Condition should be one of the enum values from
   /// AArch64SME::ToggleCondition.
   SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
-                              SDValue Chain, SDValue InGlue, unsigned Condition,
-                              SDValue PStateSM = SDValue()) const;
+                              SDValue Chain, SDValue InGlue,
+                              unsigned Condition) const;
 
   bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
 
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index 800787cc0b4f5..cb9fdb7606329 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -231,6 +231,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   // on function entry to record the initial pstate of a function.
   Register PStateSMReg = MCRegister::NoRegister;
 
+  // true if PStateSMReg is used.
+  bool PStateSMRegUsed = false;
+
   // Holds a pointer to a buffer that is large enough to represent
   // all SME ZA state and any additional state required by the
   // __arm_sme_save/restore support routines.
@@ -274,6 +277,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   Register getPStateSMReg() const { return PStateSMReg; };
   void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
 
+  unsigned IsPStateSMRegUsed() const { return PStateSMRegUsed; };
+  void setPStateSMRegUsed(bool Used = true) { PStateSMRegUsed = Used; };
+
   int64_t getVGIdx() const { return VGIdx; };
   void setVGIdx(unsigned Idx) { VGIdx = Idx; };
 
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index db27ca978980f..7b5f45e96a942 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -39,6 +39,15 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
 def AArch64CoalescerBarrier
     : SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;
 
+def AArch64CallerIsStreaming
+  : SDNode<"AArch64ISD::CALLER_IS_STREAMING", SDTypeProfile<1, 0,
+    [SDTCisInt<0>]>, [SDNPHasChain, SDNPSideEffect]>;
+
+let usesCustomInserter = 1 in {
+  def CallerIsStreaming : Pseudo<(outs GPR64:$is_streaming), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 (AArch64CallerIsStreaming)), (CallerIsStreaming)>;
+
 def AArch64VGSave : SDNode<"AArch64ISD::VG_SAVE", SDTypeProfile<0, 0, []>,
                            [SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;
 
diff --git a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
index bd28716118880..564af6708e1ed 100644
--- a/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
@@ -80,16 +80,10 @@ static bool isMatchingStartStopPair(const MachineInstr *MI1,
   if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
     return false;
 
-  // This optimisation is unlikely to happen in practice for conditional
-  // smstart/smstop pairs as the virtual registers for pstate.sm will always
-  // be different.
-  // TODO: For this optimisation to apply to conditional smstart/smstop,
-  // this pass will need to do more work to remove redundant calls to
-  // __arm_sme_state.
-
   // Only consider conditional start/stop pairs which read the same register
-  // holding the original value of pstate.sm, as some conditional start/stops
-  // require the state on entry to the function.
+  // holding the original value of pstate.sm. This is somewhat over conservative
+  // as all conditional streaming mode changes only look at the state on entry
+  // to the function.
   if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
     Register Reg1 = MI1->getOperand(3).getReg();
     Register Reg2 = MI2->getOperand(3).getReg();
diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
index 1f68815411097..ba40ccd1c7406 100644
--- a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
+++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
@@ -150,42 +150,40 @@ define i64 @streaming_compatible_agnostic_caller_nonstreaming_private_za_callee(
 ; CHECK-NEXT:    add x29, sp, #64
 ; CHECK-NEXT:    stp x20, x19, [sp, #96] // 16-byte Folded Spill
 ; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    mov x19, x0
 ; CHECK-NEXT:    bl __arm_sme_state_size
 ; CHECK-NEXT:    sub sp, sp, x0
-; CHECK-NEXT:    mov x19, sp
-; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    mov x20, sp
+; CHECK-NEXT:    mov x0, x20
 ; CHECK-NEXT:    bl __arm_sme_save
-; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    and x20, x0, #0x1
-; CHECK-NEXT:    tbz w20, #0, .LBB5_2
+; CHECK-NEXT:    tbz w19, #0, .LBB5_2
 ; CHECK-NEXT:  // %bb.1:
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:  .LBB5_2:
 ; CHECK-NEXT:    mov x0, x8
 ; CHECK-NEXT:    bl private_za_decl
-; CHECK-NEXT:    mov x2, x0
-; CHECK-NEXT:    tbz w20, #0, .LBB5_4
+; CHECK-NEXT:    mov x1, x0
+; CHECK-NEXT:    tbz w19, #0, .LBB5_4
 ; CHECK-NEXT:  // %bb.3:
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:  .LBB5_4:
-; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    mov x0, x20
 ; CHECK-NEXT:    bl __arm_sme_restore
-; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    mov x0, x20
 ; CHECK-NEXT:    bl __arm_sme_save
-; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    and x20, x0, #0x1
-; CHECK-NEXT:    tbz w20, #0, .LBB5_6
+; CHECK-NEXT:    tbz w19, #0, .LBB5_6
 ; CHECK-NEXT:  // %bb.5:
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:  .LBB5_6:
-; CHECK-NEXT:    mov x0, x2
+; CHECK-NEXT:    mov x0, x1
 ; CHECK-NEXT:    bl private_za_decl
 ; CHECK-NEXT:    mov x1, x0
-; CHECK-NEXT:    tbz w20, #0, .LBB5_8
+; CHECK-NEXT:    tbz w19, #0, .LBB5_8
 ; CHECK-NEXT:  // %bb.7:
 ; CHECK-NEXT:    smstart sm
 ; CHECK-NEXT:  .LBB5_8:
-; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    mov x0, x20
 ; CHECK-NEXT:    bl __arm_sme_restore
 ; CHECK-NEXT:    mov x0, x1
 ; CHECK-NEXT:    sub sp, x29, #64
diff --git a/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll b/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll
index c4440e7bcc3ff..1567ca258cccb 100644
--- a/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll
+++ b/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll
@@ -18,7 +18,7 @@ define void @streaming_compatible() #0 {
 ; CHECK-NEXT:    bl __arm_get_current_vg
 ; CHECK-NEXT:    stp x0, x19, [sp, #72] // 16-byte Folded Spill
 ; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    mov x19, x0
 ; CHECK-NEXT:    tbz w19, #0, .LBB0_2
 ; CHECK-NEXT:  // %bb.1:
 ; CHECK-NEXT:    smstop sm
@@ -57,7 +57,7 @@ define void @streaming_compatible_arg(float %f) #0 {
 ; CHECK-NEXT:    stp x0, x19, [sp, #88] // 16-byte Folded Spill
 ; CHECK-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
 ; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    mov x19, x0
 ; CHECK-NEXT:    tbz w19, #0, .LBB1_2
 ; CHECK-NEXT:  // %bb.1:
 ; CHECK-NEXT:    smstop sm
diff --git a/llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll b/llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
index 980144d6ca584..1933eb85b77f2 100644
--- a/llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
+++ b/llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
@@ -44,7 +44,7 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
 ; NOPAIR-NEXT:    addvl sp, sp, #-1
 ; NOPAIR-NEXT:    str z0, [sp] // 16-byte Folded Spill
 ; NOPAIR-NEXT:    bl __arm_sme_state
-; NOPAIR-NEXT:    and x19, x0, #0x1
+; NOPAIR-NEXT:    mov x19, x0
 ; NOPAIR-NEXT:    tbz w19, #0, .LBB0_2
 ; NOPAIR-NEXT:  // %bb.1:
 ; NOPAIR-NEXT:    smstop sm
@@ -126,7 +126,7 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
 ; PAIR-NEXT:    addvl sp, sp, #-1
 ; PAIR-NEXT:    str z0, [sp] // 16-byte Folded Spill
 ; PAIR-NEXT:    bl __arm_sme_state
-; PAIR-NEXT:    and x19, x0, #0x1
+; PAIR-NEXT:    mov x19, x0
 ; PAIR-NEXT:    tbz w19, #0, .LBB0_2
 ; PAIR-NEXT:  // %bb.1:
 ; PAIR-NEXT:    smstop sm
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index 4a52bf27a7591..759f3ee609e58 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -441,7 +441,7 @@ define float @frem_call_sm_compat(float %a, float %b) "aarch64_pstate_sm_compati
 ; CHECK-COMMON-NEXT:    str x19, [sp, #96] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    stp s0, s1, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    bl __arm_sme_state
-; CHECK-COMMON-NEXT:    and x19, x0, #0x1
+; CHECK-COMMON-NEXT:    mov x19, x0
 ; CHECK-COMMON-NEXT:    tbz w19, #0, .LBB12_2
 ; CHECK-COMMON-NEXT:  // %bb.1:
 ; CHECK-COMMON-NEXT:    smstop sm
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index e463e833bdbde..3f5e7e9f32a47 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -137,8 +137,10 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
 ; CHECK-NEXT:    str x9, [sp, #80] // 8-byte Folded Spill
 ; CHECK-NEXT:    stp x20, x19, [sp, #96] // 16-byte Folded Spill
 ; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    bl __arm_sme_state
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mov x20, x0
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
 ; CHECK-NEXT:    stur x9, [x29, #-80]
@@ -147,8 +149,6 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
 ; CHECK-NEXT:    stur wzr, [x29, #-68]
 ; CHECK-NEXT:    sturh w8, [x29, #-72]
 ; CHECK-NEXT:    msr TPIDR2_EL0, x9
-; CHECK-NEXT:    bl __arm_sme_state
-; CHECK-NEXT:    and x20, x0, #0x1
 ; CHECK-NEXT:    tbz w20,...
[truncated]

Comment on lines +1181 to +1183
Register Reg = cast<RegisterSDNode>(F->getOperand(1))->getReg();
if (Reg.isPhysical())
UsedRegs.push_back(Reg);
Member Author


This was adding bogus implicit-defs of virtual registers used in the CopyFromReg glued to the SMSTART/STOP. From the comment above, I believe this is a bug, and this code should only collect physregs.

Comment:

// The MachineInstr may also define physregs instead of virtregs. These
// physreg values can reach other instructions in different ways:
//
// 1. When there is a use of a Node value beyond the explicitly defined
// virtual registers, we emit a CopyFromReg for one of the implicitly
// defined physregs. This only happens when HasPhysRegOuts is true.
//
// 2. A CopyFromReg reading a physreg may be glued to this instruction.
//
// 3. A glued instruction may implicitly use a physreg.
//
// 4. A glued instruction may use a RegisterSDNode operand.
//
// Collect all the used physreg defs, and make sure that any unused physreg
// defs are marked as dead.

(I think this code corresponds to case 2.)
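For context, the physical/virtual distinction that check relies on can be sketched with a simplified stand-in for LLVM's `Register` (assuming the usual encoding in `llvm/include/llvm/CodeGen/Register.h`, where virtual registers are tagged with the top bit and 0 means "no register"; stack slots and other details are omitted):

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-in for llvm::Register, not the real class. Virtual
// registers carry the top bit; id 0 is the invalid/no-register value;
// everything else in between is a physical register.
struct Register {
  static constexpr uint32_t VirtualFlag = 1u << 31;
  uint32_t id = 0;
  bool isValid() const { return id != 0; }
  bool isVirtual() const { return (id & VirtualFlag) != 0; }
  bool isPhysical() const { return isValid() && !isVirtual(); }
};
```

With this encoding, the fixed `InstrEmitter` check records only genuine physregs (small ids like X0's) in `UsedRegs` and skips freshly created virtual registers.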

; CHECK64-NEXT: sub x8, x8, x9
; CHECK64-NEXT: mov w2, w1
; CHECK64-NEXT: mov w8, w0
; CHECK64-NEXT: bl __arm_sme_state
Member Author


Note: We can use mrs $var, SVCR when we have +sme which further cleans up code-gen (by not clobbering x0 and x1).

@MacDue MacDue changed the title [AArch64][SME] Use entry pstate.sm for conditional streaming-mode chnges [AArch64][SME] Use entry pstate.sm for conditional streaming-mode changes Aug 6, 2025
MacDue added 3 commits August 11, 2025 09:52
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment


LGTM!

@MacDue MacDue merged commit d0c9599 into llvm:main Aug 12, 2025
9 checks passed
@MacDue MacDue deleted the sm_compat branch August 12, 2025 08:15

@padipiee padipiee left a comment


good

Labels
backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well
5 participants