7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This pass implements the SME ABI requirements for ZA state. This includes
10
- // implementing the lazy ZA state save schemes around calls.
10
+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
13
@@ -128,7 +128,7 @@ struct MachineSMEABI : public MachineFunctionPass {
128
128
129
129
void collectNeededZAStates (MachineFunction &MF, SMEAttrs);
130
130
void pickBundleZAStates (MachineFunction &MF);
131
- void insertStateChanges (MachineFunction &MF);
131
+ void insertStateChanges (MachineFunction &MF, bool IsAgnosticZA );
132
132
133
133
// Emission routines for private and shared ZA functions (using lazy saves).
134
134
void emitNewZAPrologue (MachineBasicBlock &MBB,
@@ -143,11 +143,46 @@ struct MachineSMEABI : public MachineFunctionPass {
143
143
void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
144
144
bool ClearTPIDR2);
145
145
146
+ // Emission routines for agnostic ZA functions.
147
+ void emitSetupFullZASave (MachineBasicBlock &MBB,
148
+ MachineBasicBlock::iterator MBBI,
149
+ LiveRegs PhysLiveRegs);
150
+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
151
+ MachineBasicBlock::iterator MBBI,
152
+ LiveRegs PhysLiveRegs, bool IsSave);
153
+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
154
+ MachineBasicBlock::iterator MBBI,
155
+ LiveRegs PhysLiveRegs);
156
+
146
157
void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
147
- ZAState From, ZAState To, LiveRegs PhysLiveRegs);
158
+ ZAState From, ZAState To, LiveRegs PhysLiveRegs,
159
+ bool IsAgnosticZA);
160
+
161
+ // Helpers for switching between lazy/full ZA save/restore routines.
162
+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
163
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
164
+ if (IsAgnosticZA)
165
+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
166
+ return emitSetupLazySave (MBB, MBBI);
167
+ }
168
+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
169
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
170
+ if (IsAgnosticZA)
171
+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
172
+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
173
+ }
174
+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
175
+ MachineBasicBlock::iterator MBBI,
176
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
177
+ if (IsAgnosticZA)
178
+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
179
+ return emitAllocateLazySaveBuffer (MBB, MBBI);
180
+ }
148
181
149
182
TPIDR2State getTPIDR2Block (MachineFunction &MF);
150
183
184
+ Register getAgnosticZABufferPtr (MachineFunction &MF);
185
+
151
186
private:
152
187
struct InstInfo {
153
188
ZAState NeededState{ZAState::ANY};
@@ -158,6 +193,7 @@ struct MachineSMEABI : public MachineFunctionPass {
158
193
struct BlockInfo {
159
194
ZAState FixedEntryState{ZAState::ANY};
160
195
SmallVector<InstInfo> Insts;
196
+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
161
197
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
162
198
};
163
199
@@ -167,6 +203,9 @@ struct MachineSMEABI : public MachineFunctionPass {
167
203
SmallVector<ZAState> BundleStates;
168
204
std::optional<TPIDR2State> TPIDR2Block;
169
205
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
206
+ Register AgnosticZABufferPtr = AArch64::NoRegister;
207
+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
208
+ bool HasFullZASaveRestore = false ;
170
209
} State;
171
210
172
211
EdgeBundles *Bundles = nullptr ;
@@ -175,7 +214,8 @@ struct MachineSMEABI : public MachineFunctionPass {
175
214
void MachineSMEABI::collectNeededZAStates (MachineFunction &MF,
176
215
SMEAttrs SMEFnAttrs) {
177
216
const TargetRegisterInfo &TRI = *MF.getSubtarget ().getRegisterInfo ();
178
- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
217
+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
218
+ SMEFnAttrs.hasZAState ()) &&
179
219
" Expected function to have ZA/ZT0 state!" );
180
220
181
221
State.Blocks .resize (MF.getNumBlockIDs ());
@@ -209,6 +249,7 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
209
249
210
250
Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
211
251
auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
252
+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
212
253
for (MachineInstr &MI : reverse (MBB)) {
213
254
MachineBasicBlock::iterator MBBI (MI);
214
255
LiveUnits.stepBackward (MI);
@@ -219,15 +260,20 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
219
260
// block setup.
220
261
if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
221
262
State.AfterSMEProloguePt = MBBI;
263
+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
222
264
}
265
+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
223
266
auto [NeededState, InsertPt] = getInstNeededZAState (
224
- TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface ());
267
+ TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface () ||
268
+ SMEFnAttrs.hasAgnosticZAInterface ());
225
269
assert ((InsertPt == MBBI ||
226
270
InsertPt->getOpcode () == AArch64::ADJCALLSTACKDOWN) &&
227
271
" Unexpected state change insertion point!" );
228
272
// TODO: Do something to avoid state changes where NZCV is live.
229
273
if (MBBI == FirstTerminatorInsertPt)
230
274
Block.PhysLiveRegsAtExit = PhysLiveRegs;
275
+ if (MBBI == FirstNonPhiInsertPt)
276
+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
231
277
if (NeededState != ZAState::ANY)
232
278
Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
233
279
}
@@ -294,7 +340,7 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
294
340
}
295
341
}
296
342
297
- void MachineSMEABI::insertStateChanges (MachineFunction &MF) {
343
+ void MachineSMEABI::insertStateChanges (MachineFunction &MF, bool IsAgnosticZA ) {
298
344
for (MachineBasicBlock &MBB : MF) {
299
345
BlockInfo &Block = State.Blocks [MBB.getNumber ()];
300
346
ZAState InState =
@@ -309,7 +355,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
309
355
for (auto &Inst : Block.Insts ) {
310
356
if (CurrentState != Inst.NeededState )
311
357
emitStateChange (MBB, Inst.InsertPt , CurrentState, Inst.NeededState ,
312
- Inst.PhysLiveRegs );
358
+ Inst.PhysLiveRegs , IsAgnosticZA );
313
359
CurrentState = Inst.NeededState ;
314
360
}
315
361
@@ -318,7 +364,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
318
364
319
365
if (CurrentState != OutState)
320
366
emitStateChange (MBB, MBB.getFirstTerminator (), CurrentState, OutState,
321
- Block.PhysLiveRegsAtExit );
367
+ Block.PhysLiveRegsAtExit , IsAgnosticZA );
322
368
}
323
369
}
324
370
@@ -571,10 +617,98 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
571
617
emitZeroZA (TII, DL, MBB, MBBI, /* Mask=*/ 0b11111111 );
572
618
}
573
619
620
+ Register MachineSMEABI::getAgnosticZABufferPtr (MachineFunction &MF) {
621
+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
622
+ return State.AgnosticZABufferPtr ;
623
+ if (auto BufferPtr =
624
+ MF.getInfo <AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer ();
625
+ BufferPtr != AArch64::NoRegister)
626
+ State.AgnosticZABufferPtr = BufferPtr;
627
+ else
628
+ State.AgnosticZABufferPtr =
629
+ MF.getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
630
+ return State.AgnosticZABufferPtr ;
631
+ }
632
+
633
+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
634
+ MachineBasicBlock::iterator MBBI,
635
+ LiveRegs PhysLiveRegs, bool IsSave) {
636
+ MachineFunction &MF = *MBB.getParent ();
637
+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
638
+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo ();
639
+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
640
+ MachineRegisterInfo &MRI = MF.getRegInfo ();
641
+
642
+ State.HasFullZASaveRestore = true ;
643
+ DebugLoc DL = getDebugLoc (MBB, MBBI);
644
+ Register BufferPtr = AArch64::X0;
645
+
646
+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
647
+
648
+ // Copy the buffer pointer into X0.
649
+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
650
+ .addReg (getAgnosticZABufferPtr (MF));
651
+
652
+ // Call __arm_sme_save/__arm_sme_restore.
653
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
654
+ .addReg (BufferPtr, RegState::Implicit)
655
+ .addExternalSymbol (IsSave ? " __arm_sme_save" : " __arm_sme_restore" )
656
+ .addRegMask (TRI.getCallPreservedMask (
657
+ MF,
658
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
659
+ }
660
+
661
+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
662
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
663
+ LiveRegs PhysLiveRegs) {
664
+ MachineFunction &MF = *MBB.getParent ();
665
+ MachineFrameInfo &MFI = MF.getFrameInfo ();
666
+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
667
+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
668
+ MachineRegisterInfo &MRI = MF.getRegInfo ();
669
+ auto *AFI = MF.getInfo <AArch64FunctionInfo>();
670
+
671
+ // Buffer already allocated in SelectionDAG.
672
+ if (AFI->getEarlyAllocSMESaveBuffer ())
673
+ return ;
674
+
675
+ DebugLoc DL = getDebugLoc (MBB, MBBI);
676
+ Register BufferPtr = getAgnosticZABufferPtr (MF);
677
+ Register BufferSize = MRI.createVirtualRegister (&AArch64::GPR64RegClass);
678
+
679
+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
680
+
681
+ // Calculate the SME state size.
682
+ {
683
+ const AArch64RegisterInfo *TRI = Subtarget.getRegisterInfo ();
684
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
685
+ .addExternalSymbol (" __arm_sme_state_size" )
686
+ .addReg (AArch64::X0, RegState::ImplicitDefine)
687
+ .addRegMask (TRI->getCallPreservedMask (
688
+ MF, CallingConv::
689
+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
690
+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferSize)
691
+ .addReg (AArch64::X0);
692
+ }
693
+
694
+ // Allocate a buffer object of the size given __arm_sme_state_size.
695
+ {
696
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::SUBXrx64), AArch64::SP)
697
+ .addReg (AArch64::SP)
698
+ .addReg (BufferSize)
699
+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
700
+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
701
+ .addReg (AArch64::SP);
702
+
703
+ // We have just allocated a variable sized object, tell this to PEI.
704
+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
705
+ }
706
+ }
707
+
574
708
void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
575
709
MachineBasicBlock::iterator InsertPt,
576
710
ZAState From, ZAState To,
577
- LiveRegs PhysLiveRegs) {
711
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA ) {
578
712
579
713
// ZA not used.
580
714
if (From == ZAState::ANY || To == ZAState::ANY)
@@ -601,10 +735,11 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
601
735
}
602
736
603
737
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
604
- emitSetupLazySave (MBB, InsertPt);
738
+ emitZASave (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
605
739
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
606
- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
740
+ emitZARestore (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
607
741
else if (To == ZAState::OFF) {
742
+ assert (!IsAgnosticZA && " Should not turn ZA off in agnostic ZA function" );
608
743
// If we're exiting from the CALLER_DORMANT state that means this new ZA
609
744
// function did not touch ZA (so ZA was never turned on).
610
745
if (From != ZAState::CALLER_DORMANT)
@@ -627,7 +762,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
627
762
628
763
auto *AFI = MF.getInfo <AArch64FunctionInfo>();
629
764
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
630
- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
765
+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
766
+ !SMEFnAttrs.hasAgnosticZAInterface ())
631
767
return false ;
632
768
633
769
assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -636,20 +772,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
636
772
State = PassState{};
637
773
Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles ();
638
774
775
+ bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface ();
776
+
639
777
collectNeededZAStates (MF, SMEFnAttrs);
640
778
pickBundleZAStates (MF);
641
- insertStateChanges (MF);
779
+ insertStateChanges (MF, /* IsAgnosticZA= */ IsAgnosticZA );
642
780
643
781
// Allocate save buffer (if needed).
644
- if (State.TPIDR2Block .has_value ()) {
782
+ if (State.HasFullZASaveRestore || State. TPIDR2Block .has_value ()) {
645
783
if (State.AfterSMEProloguePt ) {
646
784
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
647
785
// entry block (due to the probing loop).
648
- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
649
- *State.AfterSMEProloguePt );
786
+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
787
+ *State.AfterSMEProloguePt ,
788
+ State.PhysLiveRegsAfterSMEPrologue ,
789
+ /* IsAgnosticZA=*/ IsAgnosticZA);
650
790
} else {
651
791
MachineBasicBlock &EntryBlock = MF.front ();
652
- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
792
+ emitAllocateZASaveBuffer (
793
+ EntryBlock, EntryBlock.getFirstNonPHI (),
794
+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry ,
795
+ /* IsAgnosticZA=*/ IsAgnosticZA);
653
796
}
654
797
}
655
798
0 commit comments