Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,9 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
if (Node->getValueType(Node->getNumValues()-1) == MVT::Glue) {
for (SDNode *F = Node->getGluedUser(); F; F = F->getGluedUser()) {
if (F->getOpcode() == ISD::CopyFromReg) {
UsedRegs.push_back(cast<RegisterSDNode>(F->getOperand(1))->getReg());
Register Reg = cast<RegisterSDNode>(F->getOperand(1))->getReg();
if (Reg.isPhysical())
UsedRegs.push_back(Reg);
Comment on lines +1181 to +1183
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was adding bogus implict-defs of virtual registers used in the CopyFromReg glued to the SMSTART/STOP. From the comment above, I believe this is a bug, and this should only collect physregs.

Comment:

// The MachineInstr may also define physregs instead of virtregs. These
// physreg values can reach other instructions in different ways:
//
// 1. When there is a use of a Node value beyond the explicitly defined
// virtual registers, we emit a CopyFromReg for one of the implicitly
// defined physregs. This only happens when HasPhysRegOuts is true.
//
// 2. A CopyFromReg reading a physreg may be glued to this instruction.
//
// 3. A glued instruction may implicitly use a physreg.
//
// 4. A glued instruction may use a RegisterSDNode operand.
//
// Collect all the used physreg defs, and make sure that any unused physreg
// defs are marked as dead.

(I think this code corresponds the case 2)

continue;
} else if (F->getOpcode() == ISD::CopyToReg) {
// Skip CopyToReg nodes that are internal to the glue chain.
Expand Down
90 changes: 60 additions & 30 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3101,6 +3101,31 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
return BB;
}

MachineBasicBlock *
AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI,
MachineBasicBlock *BB) const {
MachineFunction *MF = BB->getParent();
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
Register ResultReg = MI.getOperand(0).getReg();
if (FuncInfo->isPStateSMRegUsed()) {
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
.addExternalSymbol("__arm_sme_state")
.addReg(AArch64::X0, RegState::ImplicitDefine)
.addRegMask(TRI->getCallPreservedMask(
*MF, CallingConv::
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg)
.addReg(AArch64::X0);
} else {
assert(MI.getMF()->getRegInfo().use_empty(ResultReg) &&
"Expected no users of the entry pstate.sm!");
}
MI.eraseFromParent();
return BB;
}

// Helper function to find the instruction that defined a virtual register.
// If unable to find such instruction, returns nullptr.
static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
Expand Down Expand Up @@ -3216,6 +3241,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitAllocateSMESaveBuffer(MI, BB);
case AArch64::GetSMESaveSize:
return EmitGetSMESaveSize(MI, BB);
case AArch64::EntryPStateSM:
return EmitEntryPStateSM(MI, BB);
case AArch64::F128CSEL:
return EmitF128CSEL(MI, BB);
case TargetOpcode::STATEPOINT:
Expand Down Expand Up @@ -8133,19 +8160,26 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
}
assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());

if (Attrs.hasStreamingCompatibleInterface()) {
SDValue EntryPStateSM =
DAG.getNode(AArch64ISD::ENTRY_PSTATE_SM, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain});

// Copy the value to a virtual register, and save that in FuncInfo.
Register EntryPStateSMReg =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
Chain = DAG.getCopyToReg(EntryPStateSM.getValue(1), DL, EntryPStateSMReg,
EntryPStateSM);
FuncInfo->setPStateSMReg(EntryPStateSMReg);
}

// Insert the SMSTART if this is a locally streaming function and
// make sure it is Glued to the last CopyFromReg value.
if (IsLocallyStreaming) {
SDValue PStateSM;
if (Attrs.hasStreamingCompatibleInterface()) {
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
Register Reg = MF.getRegInfo().createVirtualRegister(
getRegClassFor(PStateSM.getValueType().getSimpleVT()));
FuncInfo->setPStateSMReg(Reg);
Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
if (Attrs.hasStreamingCompatibleInterface())
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
AArch64SME::IfCallerIsNonStreaming, PStateSM);
} else
AArch64SME::IfCallerIsNonStreaming);
else
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
AArch64SME::Always);

Expand Down Expand Up @@ -8836,8 +8870,7 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
bool Enable, SDValue Chain,
SDValue InGlue,
unsigned Condition,
SDValue PStateSM) const {
unsigned Condition) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);
Expand All @@ -8849,9 +8882,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
SmallVector<SDValue> Ops = {Chain, MSROp};
unsigned Opcode;
if (Condition != AArch64SME::Always) {
FuncInfo->setPStateSMRegUsed(true);
Register PStateReg = FuncInfo->getPStateSMReg();
assert(PStateReg.isValid() && "PStateSM Register is invalid");
SDValue PStateSM =
DAG.getCopyFromReg(Chain, DL, PStateReg, MVT::i64, InGlue);
// Use chain and glue from the CopyFromReg.
Ops[0] = PStateSM.getValue(1);
InGlue = PStateSM.getValue(2);
SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
assert(PStateSM && "PStateSM should be defined");
Ops.push_back(ConditionOp);
Ops.push_back(PStateSM);
} else {
Expand Down Expand Up @@ -9126,15 +9166,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
/*IsSave=*/true);
}

SDValue PStateSM;
bool RequiresSMChange = CallAttrs.requiresSMChange();
if (RequiresSMChange) {
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
PStateSM = DAG.getConstant(1, DL, MVT::i64);
else if (CallAttrs.caller().hasNonStreamingInterface())
PStateSM = DAG.getConstant(0, DL, MVT::i64);
else
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
OptimizationRemarkEmitter ORE(&MF.getFunction());
ORE.emit([&]() {
auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
Expand Down Expand Up @@ -9449,9 +9482,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
InGlue = Chain.getValue(1);
}

SDValue NewChain = changeStreamingMode(
DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
getSMToggleCondition(CallAttrs), PStateSM);
SDValue NewChain =
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
Chain, InGlue, getSMToggleCondition(CallAttrs));
Chain = NewChain.getValue(0);
InGlue = NewChain.getValue(1);
}
Expand Down Expand Up @@ -9635,10 +9668,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
InGlue = Result.getValue(Result->getNumValues() - 1);

if (RequiresSMChange) {
assert(PStateSM && "Expected a PStateSM to be set");
Result = changeStreamingMode(
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
getSMToggleCondition(CallAttrs), PStateSM);
getSMToggleCondition(CallAttrs));

if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
InGlue = Result.getValue(1);
Expand Down Expand Up @@ -9804,14 +9836,11 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
if (FuncAttrs.hasStreamingCompatibleInterface()) {
Register Reg = FuncInfo->getPStateSMReg();
assert(Reg.isValid() && "PStateSM Register is invalid");
SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
if (FuncAttrs.hasStreamingCompatibleInterface())
Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
/*Glue*/ SDValue(),
AArch64SME::IfCallerIsNonStreaming, PStateSM);
} else
AArch64SME::IfCallerIsNonStreaming);
else
Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
/*Glue*/ SDValue(), AArch64SME::Always);
Glue = Chain.getValue(1);
Expand Down Expand Up @@ -28196,6 +28225,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
case Intrinsic::aarch64_sme_in_streaming_mode: {
SDLoc DL(N);
SDValue Chain = DAG.getEntryNode();

SDValue RuntimePStateSM =
getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
Results.push_back(
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitEntryPStateSM(MachineInstr &MI,
MachineBasicBlock *BB) const;

/// Replace (0, vreg) discriminator components with the operands of blend
/// or with (immediate, NoRegister) when possible.
Expand Down Expand Up @@ -523,8 +525,8 @@ class AArch64TargetLowering : public TargetLowering {
/// node. \p Condition should be one of the enum values from
/// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
SDValue Chain, SDValue InGlue, unsigned Condition,
SDValue PStateSM = SDValue()) const;
SDValue Chain, SDValue InGlue,
unsigned Condition) const;

bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;

// true if PStateSMReg is used.
bool PStateSMRegUsed = false;

// Holds a pointer to a buffer that is large enough to represent
// all SME ZA state and any additional state required by the
// __arm_sme_save/restore support routines.
Expand Down Expand Up @@ -274,6 +277,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

unsigned isPStateSMRegUsed() const { return PStateSMRegUsed; };
void setPStateSMRegUsed(bool Used = true) { PStateSMRegUsed = Used; };

int64_t getVGIdx() const { return VGIdx; };
void setVGIdx(unsigned Idx) { VGIdx = Idx; };

Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
def AArch64CoalescerBarrier
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;

def AArch64EntryPStateSM
: SDNode<"AArch64ISD::ENTRY_PSTATE_SM", SDTypeProfile<1, 0,
[SDTCisInt<0>]>, [SDNPHasChain, SDNPSideEffect]>;

let usesCustomInserter = 1 in {
def EntryPStateSM : Pseudo<(outs GPR64:$is_streaming), (ins), []>, Sched<[]> {}
}
def : Pat<(i64 (AArch64EntryPStateSM)), (EntryPStateSM)>;

def AArch64VGSave : SDNode<"AArch64ISD::VG_SAVE", SDTypeProfile<0, 0, []>,
[SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;

Expand Down
12 changes: 3 additions & 9 deletions llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,10 @@ static bool isMatchingStartStopPair(const MachineInstr *MI1,
if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
return false;

// This optimisation is unlikely to happen in practice for conditional
// smstart/smstop pairs as the virtual registers for pstate.sm will always
// be different.
// TODO: For this optimisation to apply to conditional smstart/smstop,
// this pass will need to do more work to remove redundant calls to
// __arm_sme_state.

// Only consider conditional start/stop pairs which read the same register
// holding the original value of pstate.sm, as some conditional start/stops
// require the state on entry to the function.
// holding the original value of pstate.sm. This is somewhat over conservative
// as all conditional streaming mode changes only look at the state on entry
// to the function.
if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
Register Reg1 = MI1->getOperand(3).getReg();
Register Reg2 = MI2->getOperand(3).getReg();
Expand Down
28 changes: 13 additions & 15 deletions llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
Original file line number Diff line number Diff line change
Expand Up @@ -150,42 +150,40 @@ define i64 @streaming_compatible_agnostic_caller_nonstreaming_private_za_callee(
; CHECK-NEXT: add x29, sp, #64
; CHECK-NEXT: stp x20, x19, [sp, #96] // 16-byte Folded Spill
; CHECK-NEXT: mov x8, x0
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: bl __arm_sme_state_size
; CHECK-NEXT: sub sp, sp, x0
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: mov x0, x19
; CHECK-NEXT: mov x20, sp
; CHECK-NEXT: mov x0, x20
; CHECK-NEXT: bl __arm_sme_save
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x20, x0, #0x1
; CHECK-NEXT: tbz w20, #0, .LBB5_2
; CHECK-NEXT: tbz w19, #0, .LBB5_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB5_2:
; CHECK-NEXT: mov x0, x8
; CHECK-NEXT: bl private_za_decl
; CHECK-NEXT: mov x2, x0
; CHECK-NEXT: tbz w20, #0, .LBB5_4
; CHECK-NEXT: mov x1, x0
; CHECK-NEXT: tbz w19, #0, .LBB5_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB5_4:
; CHECK-NEXT: mov x0, x19
; CHECK-NEXT: mov x0, x20
; CHECK-NEXT: bl __arm_sme_restore
; CHECK-NEXT: mov x0, x19
; CHECK-NEXT: mov x0, x20
; CHECK-NEXT: bl __arm_sme_save
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x20, x0, #0x1
; CHECK-NEXT: tbz w20, #0, .LBB5_6
; CHECK-NEXT: tbz w19, #0, .LBB5_6
; CHECK-NEXT: // %bb.5:
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB5_6:
; CHECK-NEXT: mov x0, x2
; CHECK-NEXT: mov x0, x1
; CHECK-NEXT: bl private_za_decl
; CHECK-NEXT: mov x1, x0
; CHECK-NEXT: tbz w20, #0, .LBB5_8
; CHECK-NEXT: tbz w19, #0, .LBB5_8
; CHECK-NEXT: // %bb.7:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB5_8:
; CHECK-NEXT: mov x0, x19
; CHECK-NEXT: mov x0, x20
; CHECK-NEXT: bl __arm_sme_restore
; CHECK-NEXT: mov x0, x1
; CHECK-NEXT: sub sp, x29, #64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ define void @streaming_compatible() #0 {
; CHECK-NEXT: bl __arm_get_current_vg
; CHECK-NEXT: stp x0, x19, [sp, #72] // 16-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: tbz w19, #0, .LBB0_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstop sm
Expand Down Expand Up @@ -57,7 +57,7 @@ define void @streaming_compatible_arg(float %f) #0 {
; CHECK-NEXT: stp x0, x19, [sp, #88] // 16-byte Folded Spill
; CHECK-NEXT: str s0, [sp, #12] // 4-byte Folded Spill
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: tbz w19, #0, .LBB1_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstop sm
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; NOPAIR-NEXT: addvl sp, sp, #-1
; NOPAIR-NEXT: str z0, [sp] // 16-byte Folded Spill
; NOPAIR-NEXT: bl __arm_sme_state
; NOPAIR-NEXT: and x19, x0, #0x1
; NOPAIR-NEXT: mov x19, x0
; NOPAIR-NEXT: tbz w19, #0, .LBB0_2
; NOPAIR-NEXT: // %bb.1:
; NOPAIR-NEXT: smstop sm
Expand Down Expand Up @@ -126,7 +126,7 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; PAIR-NEXT: addvl sp, sp, #-1
; PAIR-NEXT: str z0, [sp] // 16-byte Folded Spill
; PAIR-NEXT: bl __arm_sme_state
; PAIR-NEXT: and x19, x0, #0x1
; PAIR-NEXT: mov x19, x0
; PAIR-NEXT: tbz w19, #0, .LBB0_2
; PAIR-NEXT: // %bb.1:
; PAIR-NEXT: smstop sm
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ define float @frem_call_sm_compat(float %a, float %b) "aarch64_pstate_sm_compati
; CHECK-COMMON-NEXT: str x19, [sp, #96] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: bl __arm_sme_state
; CHECK-COMMON-NEXT: and x19, x0, #0x1
; CHECK-COMMON-NEXT: mov x19, x0
; CHECK-COMMON-NEXT: tbz w19, #0, .LBB12_2
; CHECK-COMMON-NEXT: // %bb.1:
; CHECK-COMMON-NEXT: smstop sm
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
; CHECK-NEXT: str x9, [sp, #80] // 8-byte Folded Spill
; CHECK-NEXT: stp x20, x19, [sp, #96] // 16-byte Folded Spill
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: mov x20, x0
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: stur x9, [x29, #-80]
Expand All @@ -147,8 +149,6 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
; CHECK-NEXT: stur wzr, [x29, #-68]
; CHECK-NEXT: sturh w8, [x29, #-72]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x20, x0, #0x1
; CHECK-NEXT: tbz w20, #0, .LBB3_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstop sm
Expand Down
Loading