7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This pass implements the SME ABI requirements for ZA state. This includes
10
- // implementing the lazy ZA state save schemes around calls.
10
+ // implementing the lazy (and agnostic) ZA state save schemes around calls.
11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
13
@@ -128,7 +128,7 @@ struct MachineSMEABI : public MachineFunctionPass {
128
128
129
129
void collectNeededZAStates (MachineFunction &MF, SMEAttrs);
130
130
void pickBundleZAStates (MachineFunction &MF);
131
- void insertStateChanges (MachineFunction &MF);
131
+ void insertStateChanges (MachineFunction &MF, bool IsAgnosticZA );
132
132
133
133
// Emission routines for private and shared ZA functions (using lazy saves).
134
134
void emitNewZAPrologue (MachineBasicBlock &MBB,
@@ -143,11 +143,46 @@ struct MachineSMEABI : public MachineFunctionPass {
143
143
void emitZAOff (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
144
144
bool ClearTPIDR2);
145
145
146
+ // Emission routines for agnostic ZA functions.
147
+ void emitSetupFullZASave (MachineBasicBlock &MBB,
148
+ MachineBasicBlock::iterator MBBI,
149
+ LiveRegs PhysLiveRegs);
150
+ void emitFullZASaveRestore (MachineBasicBlock &MBB,
151
+ MachineBasicBlock::iterator MBBI,
152
+ LiveRegs PhysLiveRegs, bool IsSave);
153
+ void emitAllocateFullZASaveBuffer (MachineBasicBlock &MBB,
154
+ MachineBasicBlock::iterator MBBI,
155
+ LiveRegs PhysLiveRegs);
156
+
146
157
void emitStateChange (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
147
- ZAState From, ZAState To, LiveRegs PhysLiveRegs);
158
+ ZAState From, ZAState To, LiveRegs PhysLiveRegs,
159
+ bool IsAgnosticZA);
160
+
161
+ // Helpers for switching between lazy/full ZA save/restore routines.
162
+ void emitZASave (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
163
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
164
+ if (IsAgnosticZA)
165
+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ true );
166
+ return emitSetupLazySave (MBB, MBBI);
167
+ }
168
+ void emitZARestore (MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
169
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
170
+ if (IsAgnosticZA)
171
+ return emitFullZASaveRestore (MBB, MBBI, PhysLiveRegs, /* IsSave=*/ false );
172
+ return emitRestoreLazySave (MBB, MBBI, PhysLiveRegs);
173
+ }
174
+ void emitAllocateZASaveBuffer (MachineBasicBlock &MBB,
175
+ MachineBasicBlock::iterator MBBI,
176
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA) {
177
+ if (IsAgnosticZA)
178
+ return emitAllocateFullZASaveBuffer (MBB, MBBI, PhysLiveRegs);
179
+ return emitAllocateLazySaveBuffer (MBB, MBBI);
180
+ }
148
181
149
182
TPIDR2State getTPIDR2Block (MachineFunction &MF);
150
183
184
+ Register getAgnosticZABufferPtr (MachineFunction &MF);
185
+
151
186
private:
152
187
struct InstInfo {
153
188
ZAState NeededState{ZAState::ANY};
@@ -158,6 +193,7 @@ struct MachineSMEABI : public MachineFunctionPass {
158
193
struct BlockInfo {
159
194
ZAState FixedEntryState{ZAState::ANY};
160
195
SmallVector<InstInfo> Insts;
196
+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
161
197
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
162
198
};
163
199
@@ -167,6 +203,9 @@ struct MachineSMEABI : public MachineFunctionPass {
167
203
SmallVector<ZAState> BundleStates;
168
204
std::optional<TPIDR2State> TPIDR2Block;
169
205
std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
206
+ Register AgnosticZABufferPtr = AArch64::NoRegister;
207
+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
208
+ bool HasFullZASaveRestore = false ;
170
209
} State;
171
210
172
211
EdgeBundles *Bundles = nullptr ;
@@ -175,7 +214,8 @@ struct MachineSMEABI : public MachineFunctionPass {
175
214
void MachineSMEABI::collectNeededZAStates (MachineFunction &MF,
176
215
SMEAttrs SMEFnAttrs) {
177
216
const TargetRegisterInfo &TRI = *MF.getSubtarget ().getRegisterInfo ();
178
- assert ((SMEFnAttrs.hasZT0State () || SMEFnAttrs.hasZAState ()) &&
217
+ assert ((SMEFnAttrs.hasAgnosticZAInterface () || SMEFnAttrs.hasZT0State () ||
218
+ SMEFnAttrs.hasZAState ()) &&
179
219
" Expected function to have ZA/ZT0 state!" );
180
220
181
221
State.Blocks .resize (MF.getNumBlockIDs ());
@@ -209,6 +249,7 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
209
249
210
250
Block.PhysLiveRegsAtExit = GetPhysLiveRegs ();
211
251
auto FirstTerminatorInsertPt = MBB.getFirstTerminator ();
252
+ auto FirstNonPhiInsertPt = MBB.getFirstNonPHI ();
212
253
for (MachineInstr &MI : reverse (MBB)) {
213
254
MachineBasicBlock::iterator MBBI (MI);
214
255
LiveUnits.stepBackward (MI);
@@ -219,15 +260,20 @@ void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
219
260
// block setup.
220
261
if (MI.getOpcode () == AArch64::SMEStateAllocPseudo) {
221
262
State.AfterSMEProloguePt = MBBI;
263
+ State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
222
264
}
265
+ // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
223
266
auto [NeededState, InsertPt] = getInstNeededZAState (
224
- TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface ());
267
+ TRI, MI, /* ZALiveAtReturn=*/ SMEFnAttrs.hasSharedZAInterface () ||
268
+ SMEFnAttrs.hasAgnosticZAInterface ());
225
269
assert ((InsertPt == MBBI ||
226
270
InsertPt->getOpcode () == AArch64::ADJCALLSTACKDOWN) &&
227
271
" Unexpected state change insertion point!" );
228
272
// TODO: Do something to avoid state changes where NZCV is live.
229
273
if (MBBI == FirstTerminatorInsertPt)
230
274
Block.PhysLiveRegsAtExit = PhysLiveRegs;
275
+ if (MBBI == FirstNonPhiInsertPt)
276
+ Block.PhysLiveRegsAtEntry = PhysLiveRegs;
231
277
if (NeededState != ZAState::ANY)
232
278
Block.Insts .push_back ({NeededState, InsertPt, PhysLiveRegs});
233
279
}
@@ -294,7 +340,7 @@ void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
294
340
}
295
341
}
296
342
297
- void MachineSMEABI::insertStateChanges (MachineFunction &MF) {
343
+ void MachineSMEABI::insertStateChanges (MachineFunction &MF, bool IsAgnosticZA ) {
298
344
for (MachineBasicBlock &MBB : MF) {
299
345
BlockInfo &Block = State.Blocks [MBB.getNumber ()];
300
346
ZAState InState =
@@ -309,7 +355,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
309
355
for (auto &Inst : Block.Insts ) {
310
356
if (CurrentState != Inst.NeededState )
311
357
emitStateChange (MBB, Inst.InsertPt , CurrentState, Inst.NeededState ,
312
- Inst.PhysLiveRegs );
358
+ Inst.PhysLiveRegs , IsAgnosticZA );
313
359
CurrentState = Inst.NeededState ;
314
360
}
315
361
@@ -318,7 +364,7 @@ void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
318
364
319
365
if (CurrentState != OutState)
320
366
emitStateChange (MBB, MBB.getFirstTerminator (), CurrentState, OutState,
321
- Block.PhysLiveRegsAtExit );
367
+ Block.PhysLiveRegsAtExit , IsAgnosticZA );
322
368
}
323
369
}
324
370
@@ -573,10 +619,98 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
573
619
emitZeroZA (TII, DL, MBB, MBBI, /* Mask=*/ 0b11111111 );
574
620
}
575
621
622
+ Register MachineSMEABI::getAgnosticZABufferPtr (MachineFunction &MF) {
623
+ if (State.AgnosticZABufferPtr != AArch64::NoRegister)
624
+ return State.AgnosticZABufferPtr ;
625
+ if (auto BufferPtr =
626
+ MF.getInfo <AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer ();
627
+ BufferPtr != AArch64::NoRegister)
628
+ State.AgnosticZABufferPtr = BufferPtr;
629
+ else
630
+ State.AgnosticZABufferPtr =
631
+ MF.getRegInfo ().createVirtualRegister (&AArch64::GPR64RegClass);
632
+ return State.AgnosticZABufferPtr ;
633
+ }
634
+
635
+ void MachineSMEABI::emitFullZASaveRestore (MachineBasicBlock &MBB,
636
+ MachineBasicBlock::iterator MBBI,
637
+ LiveRegs PhysLiveRegs, bool IsSave) {
638
+ MachineFunction &MF = *MBB.getParent ();
639
+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
640
+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo ();
641
+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
642
+ MachineRegisterInfo &MRI = MF.getRegInfo ();
643
+
644
+ State.HasFullZASaveRestore = true ;
645
+ DebugLoc DL = getDebugLoc (MBB, MBBI);
646
+ Register BufferPtr = AArch64::X0;
647
+
648
+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
649
+
650
+ // Copy the buffer pointer into X0.
651
+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
652
+ .addReg (getAgnosticZABufferPtr (MF));
653
+
654
+ // Call __arm_sme_save/__arm_sme_restore.
655
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
656
+ .addReg (BufferPtr, RegState::Implicit)
657
+ .addExternalSymbol (IsSave ? " __arm_sme_save" : " __arm_sme_restore" )
658
+ .addRegMask (TRI.getCallPreservedMask (
659
+ MF,
660
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
661
+ }
662
+
663
+ void MachineSMEABI::emitAllocateFullZASaveBuffer (
664
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
665
+ LiveRegs PhysLiveRegs) {
666
+ MachineFunction &MF = *MBB.getParent ();
667
+ MachineFrameInfo &MFI = MF.getFrameInfo ();
668
+ auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
669
+ const TargetInstrInfo &TII = *Subtarget.getInstrInfo ();
670
+ MachineRegisterInfo &MRI = MF.getRegInfo ();
671
+ auto *AFI = MF.getInfo <AArch64FunctionInfo>();
672
+
673
+ // Buffer already allocated in SelectionDAG.
674
+ if (AFI->getEarlyAllocSMESaveBuffer ())
675
+ return ;
676
+
677
+ DebugLoc DL = getDebugLoc (MBB, MBBI);
678
+ Register BufferPtr = getAgnosticZABufferPtr (MF);
679
+ Register BufferSize = MRI.createVirtualRegister (&AArch64::GPR64RegClass);
680
+
681
+ ScopedPhysRegSave ScopedPhysRegSave (MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
682
+
683
+ // Calculate the SME state size.
684
+ {
685
+ const AArch64RegisterInfo *TRI = Subtarget.getRegisterInfo ();
686
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::BL))
687
+ .addExternalSymbol (" __arm_sme_state_size" )
688
+ .addReg (AArch64::X0, RegState::ImplicitDefine)
689
+ .addRegMask (TRI->getCallPreservedMask (
690
+ MF, CallingConv::
691
+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
692
+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferSize)
693
+ .addReg (AArch64::X0);
694
+ }
695
+
696
+ // Allocate a buffer object of the size given __arm_sme_state_size.
697
+ {
698
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::SUBXrx64), AArch64::SP)
699
+ .addReg (AArch64::SP)
700
+ .addReg (BufferSize)
701
+ .addImm (AArch64_AM::getArithExtendImm (AArch64_AM::UXTX, 0 ));
702
+ BuildMI (MBB, MBBI, DL, TII.get (TargetOpcode::COPY), BufferPtr)
703
+ .addReg (AArch64::SP);
704
+
705
+ // We have just allocated a variable sized object, tell this to PEI.
706
+ MFI.CreateVariableSizedObject (Align (16 ), nullptr );
707
+ }
708
+ }
709
+
576
710
void MachineSMEABI::emitStateChange (MachineBasicBlock &MBB,
577
711
MachineBasicBlock::iterator InsertPt,
578
712
ZAState From, ZAState To,
579
- LiveRegs PhysLiveRegs) {
713
+ LiveRegs PhysLiveRegs, bool IsAgnosticZA ) {
580
714
581
715
// ZA not used.
582
716
if (From == ZAState::ANY || To == ZAState::ANY)
@@ -603,10 +737,11 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
603
737
}
604
738
605
739
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
606
- emitSetupLazySave (MBB, InsertPt);
740
+ emitZASave (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
607
741
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
608
- emitRestoreLazySave (MBB, InsertPt, PhysLiveRegs);
742
+ emitZARestore (MBB, InsertPt, PhysLiveRegs, IsAgnosticZA );
609
743
else if (To == ZAState::OFF) {
744
+ assert (!IsAgnosticZA && " Should not turn ZA off in agnostic ZA function" );
610
745
// If we're exiting from the CALLER_DORMANT state that means this new ZA
611
746
// function did not touch ZA (so ZA was never turned on).
612
747
if (From != ZAState::CALLER_DORMANT)
@@ -629,7 +764,8 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
629
764
630
765
auto *AFI = MF.getInfo <AArch64FunctionInfo>();
631
766
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs ();
632
- if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State ())
767
+ if (!SMEFnAttrs.hasZAState () && !SMEFnAttrs.hasZT0State () &&
768
+ !SMEFnAttrs.hasAgnosticZAInterface ())
633
769
return false ;
634
770
635
771
assert (MF.getRegInfo ().isSSA () && " Expected to be run on SSA form!" );
@@ -638,20 +774,27 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
638
774
State = PassState{};
639
775
Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles ();
640
776
777
+ bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface ();
778
+
641
779
collectNeededZAStates (MF, SMEFnAttrs);
642
780
pickBundleZAStates (MF);
643
- insertStateChanges (MF);
781
+ insertStateChanges (MF, /* IsAgnosticZA= */ IsAgnosticZA );
644
782
645
783
// Allocate save buffer (if needed).
646
- if (State.TPIDR2Block .has_value ()) {
784
+ if (State.HasFullZASaveRestore || State. TPIDR2Block .has_value ()) {
647
785
if (State.AfterSMEProloguePt ) {
648
786
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
649
787
// entry block (due to the probing loop).
650
- emitAllocateLazySaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
651
- *State.AfterSMEProloguePt );
788
+ emitAllocateZASaveBuffer (*(*State.AfterSMEProloguePt )->getParent (),
789
+ *State.AfterSMEProloguePt ,
790
+ State.PhysLiveRegsAfterSMEPrologue ,
791
+ /* IsAgnosticZA=*/ IsAgnosticZA);
652
792
} else {
653
793
MachineBasicBlock &EntryBlock = MF.front ();
654
- emitAllocateLazySaveBuffer (EntryBlock, EntryBlock.getFirstNonPHI ());
794
+ emitAllocateZASaveBuffer (
795
+ EntryBlock, EntryBlock.getFirstNonPHI (),
796
+ State.Blocks [EntryBlock.getNumber ()].PhysLiveRegsAtEntry ,
797
+ /* IsAgnosticZA=*/ IsAgnosticZA);
655
798
}
656
799
}
657
800
0 commit comments