@@ -8244,53 +8244,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
   if (Subtarget->hasCustomCallingConv())
     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
 
-  // Create a 16 Byte TPIDR2 object. The dynamic buffer
-  // will be expanded and stored in the static object later using a pseudonode.
-  if (Attrs.hasZAState()) {
-    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
-    TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
-    SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                              DAG.getConstant(1, DL, MVT::i32));
-
-    SDValue Buffer;
-    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
-      Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
-                           DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
-    } else {
-      SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
-      Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
-                           DAG.getVTList(MVT::i64, MVT::Other),
-                           {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
-      MFI.CreateVariableSizedObject(Align(16), nullptr);
-    }
-    Chain = DAG.getNode(
-        AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
-        {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
-  } else if (Attrs.hasAgnosticZAInterface()) {
-    // Call __arm_sme_state_size().
-    SDValue BufferSize =
-        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
-                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
-    Chain = BufferSize.getValue(1);
-
-    SDValue Buffer;
-    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
-      Buffer =
-          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
-                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
-    } else {
-      // Allocate space dynamically.
-      Buffer = DAG.getNode(
-          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
-          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
-      MFI.CreateVariableSizedObject(Align(16), nullptr);
+  if (!Subtarget->useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
+    // Old SME ABI lowering (deprecated):
+    // Create a 16 Byte TPIDR2 object. The dynamic buffer
+    // will be expanded and stored in the static object later using a
+    // pseudonode.
+    if (Attrs.hasZAState()) {
+      TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+      TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
+      SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                                DAG.getConstant(1, DL, MVT::i32));
+      SDValue Buffer;
+      if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+        Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
+                             DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
+      } else {
+        SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+        Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
+                             DAG.getVTList(MVT::i64, MVT::Other),
+                             {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
+        MFI.CreateVariableSizedObject(Align(16), nullptr);
+      }
+      Chain = DAG.getNode(
+          AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
+          {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+    } else if (Attrs.hasAgnosticZAInterface()) {
+      // Call __arm_sme_state_size().
+      SDValue BufferSize =
+          DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), Chain);
+      Chain = BufferSize.getValue(1);
+      SDValue Buffer;
+      if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+        Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                             DAG.getVTList(MVT::i64, MVT::Other),
+                             {Chain, BufferSize});
+      } else {
+        // Allocate space dynamically.
+        Buffer = DAG.getNode(
+            ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+            {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+        MFI.CreateVariableSizedObject(Align(16), nullptr);
+      }
+      // Copy the value to a virtual register, and save that in FuncInfo.
+      Register BufferPtr =
+          MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+      FuncInfo->setSMESaveBufferAddr(BufferPtr);
+      Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
     }
-
-    // Copy the value to a virtual register, and save that in FuncInfo.
-    Register BufferPtr =
-        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
-    FuncInfo->setSMESaveBufferAddr(BufferPtr);
-    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
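
Note: the 16-byte stack object created in the old lowering above is the SME ABI's TPIDR2 block, and the buffer it points at is SVL.B x SVL.B bytes (which is why the Windows/stack-probe path computes Size = SVL * SVL). A minimal C++ sketch of the block layout, with illustrative field names that are not taken from this patch:

    #include <cstdint>

    // Rough shape of the 16-byte TPIDR2 block described by the SME ABI.
    struct TPIDR2Block {
      void *ZASaveBuffer;     // bytes 0-7: pointer to the lazy-save buffer
      uint16_t NumSaveSlices; // bytes 8-9: number of ZA slices to save
      uint8_t Reserved[6];    // bytes 10-15: must be zero
    };
    static_assert(sizeof(TPIDR2Block) == 16,
                  "matches MFI.CreateStackObject(16, Align(16), false)");
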
@@ -8307,6 +8308,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     }
   }
 
+  if (Subtarget->useNewSMEABILowering()) {
+    // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
+    if (Attrs.isNewZT0())
+      Chain = DAG.getNode(
+          ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+          DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
+          DAG.getTargetConstant(0, DL, MVT::i32));
+  }
+
   return Chain;
 }
 
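
For reference, Attrs.isNewZT0() corresponds to functions that create new ZT0 state. Assuming the usual ACLE attribute spelling and placement (not part of this patch), a source-level trigger for the intrinsic emitted above would look roughly like:

    // Hypothetical example: a function with "new" ZT0 state. Its ZT0 contents
    // are undefined on entry, so the compiler zeroes them; under the new
    // lowering this is the aarch64_sme_zero_zt intrinsic added above.
    __arm_new("zt0") void setup_lookup_table(void) {
      // ... fill ZT0 via SME2 intrinsics ...
    }
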
@@ -8871,14 +8881,12 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
   MachineFunction &MF = DAG.getMachineFunction();
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   FuncInfo->setSMESaveBufferUsed();
-
   TargetLowering::ArgListTy Args;
   TargetLowering::ArgListEntry Entry;
   Entry.Ty = PointerType::getUnqual(*DAG.getContext());
   Entry.Node =
       DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
   Args.push_back(Entry);
-
   SDValue Callee =
       DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
                             TLI.getPointerTy(DAG.getDataLayout()));
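
emitSMEStateSaveRestore builds a call that passes a single untyped pointer (the save-buffer address stashed in FuncInfo earlier). The SME support routines it targets have roughly the following shape; attribute qualifiers are omitted here, so treat this as a sketch rather than the authoritative declarations:

    #include <cstdint>

    extern "C" {
    // Number of bytes needed to hold the caller's ZA/ZT0 state.
    uint64_t __arm_sme_state_size(void);
    // Save the live SME state into *blk, or restore it from *blk.
    void __arm_sme_save(void *blk);
    void __arm_sme_restore(void *blk);
    }
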
@@ -9001,6 +9009,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall())
     CSInfo = MachineFunction::CallSiteInfo(*CB);
 
+  // Determine whether we need any streaming mode changes.
+  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
+
   // Check callee args/returns for SVE registers and set calling convention
   // accordingly.
   if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
@@ -9014,14 +9025,26 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       CallConv = CallingConv::AArch64_SVE_VectorCall;
   }
 
+  bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
+  bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
+  auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
+    // TODO: Handle agnostic ZA functions.
+    if (!UseNewSMEABILowering || IsAgnosticZAFunction)
+      return std::nullopt;
+    if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
+      return std::nullopt;
+    return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
+                                        : AArch64ISD::INOUT_ZA_USE;
+  }();
+
   if (IsTailCall) {
     // Check if it's really possible to do a tail call.
     IsTailCall = isEligibleForTailCallOptimization(CLI);
 
     // A sibling call is one where we're under the usual C ABI and not planning
     // to change that but can still do a tail call:
-    if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
-        CallConv != CallingConv::SwiftTail)
+    if (!ZAMarkerNode.has_value() && !TailCallOpt && IsTailCall &&
+        CallConv != CallingConv::Tail && CallConv != CallingConv::SwiftTail)
       IsSibCall = true;
 
     if (IsTailCall)
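
Restated outside the immediately-invoked lambda (purely illustrative; the helper name and parameters below are not part of the patch, and the AArch64ISD opcodes are LLVM-internal), the marker selection reduces to:

    #include <optional>

    // Which ZA marker, if any, a call site gets under the new SME ABI lowering.
    std::optional<unsigned> pickZAMarker(bool UseNewLowering,
                                         bool CallerIsAgnosticZA,
                                         bool CallerHasZAOrZT0State,
                                         bool CallNeedsLazySave) {
      if (!UseNewLowering || CallerIsAgnosticZA) // TODO: agnostic ZA unhandled
        return std::nullopt;
      if (!CallerHasZAOrZT0State)
        return std::nullopt;
      return CallNeedsLazySave ? AArch64ISD::REQUIRES_ZA_SAVE
                               : AArch64ISD::INOUT_ZA_USE;
    }

A call that receives a marker also loses sibling-call eligibility in the change above, because the marker has to be glued to a CALLSEQ_START and sibling calls never emit one.
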
@@ -9073,9 +9096,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
   }
 
-  // Determine whether we need any streaming mode changes.
-  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
-
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
     R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9089,7 +9109,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     return R;
   };
 
-  bool RequiresLazySave = CallAttrs.requiresLazySave();
+  bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
   bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9171,10 +9191,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
 
-  // Adjust the stack pointer for the new arguments...
+  // Adjust the stack pointer for the new arguments... and mark ZA uses.
   // These operations are automatically eliminated by the prolog/epilog pass
-  if (!IsSibCall)
+  assert((!IsSibCall || !ZAMarkerNode.has_value()) &&
+         "ZA markers require CALLSEQ_START");
+  if (!IsSibCall) {
     Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
+    if (ZAMarkerNode) {
+      // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to; simply
+      // using a chain can result in incorrect scheduling. The markers refer to
+      // the position just before the CALLSEQ_START (though they occur after
+      // it, as CALLSEQ_START lacks in-glue).
+      Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
+                          {Chain, Chain.getValue(1)});
+    }
+  }
 
   SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
                                         getPointerTy(DAG.getDataLayout()));
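
The chain-plus-glue operand list is the key detail here: a node that only chained off CALLSEQ_START could legally have unrelated chained work scheduled between the two, whereas also consuming the glue result pins the marker immediately after it. A compressed sketch of the idiom (variable names are illustrative, not from the patch):

    // CALLSEQ_START produces (chain, glue); taking both values keeps the
    // marker adjacent to it during scheduling.
    SDValue CallSeqStart = DAG.getCALLSEQ_START(Chain, NumBytes, 0, DL);
    SDValue Marker = DAG.getNode(AArch64ISD::REQUIRES_ZA_SAVE, DL,
                                 DAG.getVTList(MVT::Other),
                                 {CallSeqStart, CallSeqStart.getValue(1)});
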
@@ -9646,7 +9677,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     }
   }
 
-  if (CallAttrs.requiresEnablingZAAfterCall())
+  if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9667,7 +9698,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     SDValue TPIDR2_EL0 = DAG.getNode(
         ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
         DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
     // Copy the address of the TPIDR2 block into X0 before 'calling' the
     // RESTORE_ZA pseudo.
     SDValue Glue;
@@ -9679,7 +9709,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
                     {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
                      RestoreRoutine, RegMask, Result.getValue(1)});
-
     // Finally reset the TPIDR2_EL0 register to 0.
     Result = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Result,