diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp index 80a7529002ac9..e64b35d230d48 100644 --- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp +++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp @@ -310,7 +310,14 @@ class WaitcntBrackets { bool counterOutOfOrder(InstCounterType T) const; void simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const; void simplifyWaitcnt(InstCounterType T, unsigned &Count) const; - void determineWait(InstCounterType T, int RegNo, AMDGPU::Waitcnt &Wait) const; + + void determineWait(InstCounterType T, RegInterval Interval, + AMDGPU::Waitcnt &Wait) const; + void determineWait(InstCounterType T, int RegNo, + AMDGPU::Waitcnt &Wait) const { + determineWait(T, {RegNo, RegNo + 1}, Wait); + } + void applyWaitcnt(const AMDGPU::Waitcnt &Wait); void applyWaitcnt(InstCounterType T, unsigned Count); void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI, @@ -345,16 +352,22 @@ class WaitcntBrackets { LastFlat[DS_CNT] = ScoreUBs[DS_CNT]; } - // Return true if there might be pending writes to the specified vgpr by VMEM + // Return true if there might be pending writes to the vgpr-interval by VMEM // instructions with types different from V. - bool hasOtherPendingVmemTypes(int GprNo, VmemType V) const { - assert(GprNo < NUM_ALL_VGPRS); - return VgprVmemTypes[GprNo] & ~(1 << V); + bool hasOtherPendingVmemTypes(RegInterval Interval, VmemType V) const { + for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { + assert(RegNo < NUM_ALL_VGPRS); + if (VgprVmemTypes[RegNo] & ~(1 << V)) + return true; + } + return false; } - void clearVgprVmemTypes(int GprNo) { - assert(GprNo < NUM_ALL_VGPRS); - VgprVmemTypes[GprNo] = 0; + void clearVgprVmemTypes(RegInterval Interval) { + for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { + assert(RegNo < NUM_ALL_VGPRS); + VgprVmemTypes[RegNo] = 0; + } } void setStateOnFunctionEntryOrReturn() { @@ -396,19 +409,16 @@ class WaitcntBrackets { } void setRegScore(int GprNo, InstCounterType T, unsigned Val) { - if (GprNo < NUM_ALL_VGPRS) { - VgprUB = std::max(VgprUB, GprNo); - VgprScores[T][GprNo] = Val; - } else { - assert(T == SmemAccessCounter); - SgprUB = std::max(SgprUB, GprNo - NUM_ALL_VGPRS); - SgprScores[GprNo - NUM_ALL_VGPRS] = Val; - } + setScoreByInterval({GprNo, GprNo + 1}, T, Val); } - void setExpScore(const MachineInstr *MI, const SIRegisterInfo *TRI, - const MachineRegisterInfo *MRI, const MachineOperand &Op, - unsigned Val); + void setScoreByInterval(RegInterval Interval, InstCounterType CntTy, + unsigned Score); + + void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI, + const MachineRegisterInfo *MRI, + const MachineOperand &Op, InstCounterType CntTy, + unsigned Val); const GCNSubtarget *ST = nullptr; InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS; @@ -772,17 +782,30 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI, return Result; } -void WaitcntBrackets::setExpScore(const MachineInstr *MI, - const SIRegisterInfo *TRI, - const MachineRegisterInfo *MRI, - const MachineOperand &Op, unsigned Val) { - RegInterval Interval = getRegInterval(MI, MRI, TRI, Op); - assert(TRI->isVectorRegister(*MRI, Op.getReg())); +void WaitcntBrackets::setScoreByInterval(RegInterval Interval, + InstCounterType CntTy, + unsigned Score) { for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { - setRegScore(RegNo, EXP_CNT, Val); + if (RegNo < NUM_ALL_VGPRS) { + VgprUB = std::max(VgprUB, RegNo); + VgprScores[CntTy][RegNo] = Score; + } else { + assert(CntTy == SmemAccessCounter); + SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS); + SgprScores[RegNo - NUM_ALL_VGPRS] = Score; + } } } +void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI, + const SIRegisterInfo *TRI, + const MachineRegisterInfo *MRI, + const MachineOperand &Op, + InstCounterType CntTy, unsigned Score) { + RegInterval Interval = getRegInterval(MI, MRI, TRI, Op); + setScoreByInterval(Interval, CntTy, Score); +} + void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI, const MachineRegisterInfo *MRI, @@ -806,57 +829,61 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, // All GDS operations must protect their address register (same as // export.) if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr)) - setExpScore(&Inst, TRI, MRI, *AddrOp, CurrScore); + setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore); if (Inst.mayStore()) { if (const auto *Data0 = TII->getNamedOperand(Inst, AMDGPU::OpName::data0)) - setExpScore(&Inst, TRI, MRI, *Data0, CurrScore); + setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore); if (const auto *Data1 = TII->getNamedOperand(Inst, AMDGPU::OpName::data1)) - setExpScore(&Inst, TRI, MRI, *Data1, CurrScore); + setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) && Inst.getOpcode() != AMDGPU::DS_APPEND && Inst.getOpcode() != AMDGPU::DS_CONSUME && Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) { for (const MachineOperand &Op : Inst.all_uses()) { if (TRI->isVectorRegister(*MRI, Op.getReg())) - setExpScore(&Inst, TRI, MRI, Op, CurrScore); + setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore); } } } else if (TII->isFLAT(Inst)) { if (Inst.mayStore()) { - setExpScore(&Inst, TRI, MRI, - *TII->getNamedOperand(Inst, AMDGPU::OpName::data), - CurrScore); + setScoreByOperand(&Inst, TRI, MRI, + *TII->getNamedOperand(Inst, AMDGPU::OpName::data), + EXP_CNT, CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst)) { - setExpScore(&Inst, TRI, MRI, - *TII->getNamedOperand(Inst, AMDGPU::OpName::data), - CurrScore); + setScoreByOperand(&Inst, TRI, MRI, + *TII->getNamedOperand(Inst, AMDGPU::OpName::data), + EXP_CNT, CurrScore); } } else if (TII->isMIMG(Inst)) { if (Inst.mayStore()) { - setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore); + setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, + CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst)) { - setExpScore(&Inst, TRI, MRI, - *TII->getNamedOperand(Inst, AMDGPU::OpName::data), - CurrScore); + setScoreByOperand(&Inst, TRI, MRI, + *TII->getNamedOperand(Inst, AMDGPU::OpName::data), + EXP_CNT, CurrScore); } } else if (TII->isMTBUF(Inst)) { if (Inst.mayStore()) - setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore); + setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, + CurrScore); } else if (TII->isMUBUF(Inst)) { if (Inst.mayStore()) { - setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore); + setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT, + CurrScore); } else if (SIInstrInfo::isAtomicRet(Inst)) { - setExpScore(&Inst, TRI, MRI, - *TII->getNamedOperand(Inst, AMDGPU::OpName::data), - CurrScore); + setScoreByOperand(&Inst, TRI, MRI, + *TII->getNamedOperand(Inst, AMDGPU::OpName::data), + EXP_CNT, CurrScore); } } else if (TII->isLDSDIR(Inst)) { // LDSDIR instructions attach the score to the destination. - setExpScore(&Inst, TRI, MRI, - *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), CurrScore); + setScoreByOperand(&Inst, TRI, MRI, + *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), + EXP_CNT, CurrScore); } else { if (TII->isEXP(Inst)) { // For export the destination registers are really temps that @@ -865,15 +892,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, // score. for (MachineOperand &DefMO : Inst.all_defs()) { if (TRI->isVGPR(*MRI, DefMO.getReg())) { - setRegScore( - TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)), - EXP_CNT, CurrScore); + setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore); } } } for (const MachineOperand &Op : Inst.all_uses()) { if (TRI->isVectorRegister(*MRI, Op.getReg())) - setExpScore(&Inst, TRI, MRI, Op, CurrScore); + setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore); } } } else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ { @@ -901,9 +926,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII, VgprVmemTypes[RegNo] |= 1 << V; } } - for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { - setRegScore(RegNo, T, CurrScore); - } + setScoreByInterval(Interval, T, CurrScore); } if (Inst.mayStore() && (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) { @@ -1034,31 +1057,34 @@ void WaitcntBrackets::simplifyWaitcnt(InstCounterType T, Count = ~0u; } -void WaitcntBrackets::determineWait(InstCounterType T, int RegNo, +void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval, AMDGPU::Waitcnt &Wait) const { - unsigned ScoreToWait = getRegScore(RegNo, T); - - // If the score of src_operand falls within the bracket, we need an - // s_waitcnt instruction. const unsigned LB = getScoreLB(T); const unsigned UB = getScoreUB(T); - if ((UB >= ScoreToWait) && (ScoreToWait > LB)) { - if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() && - !ST->hasFlatLgkmVMemCountInOrder()) { - // If there is a pending FLAT operation, and this is a VMem or LGKM - // waitcnt and the target can report early completion, then we need - // to force a waitcnt 0. - addWait(Wait, T, 0); - } else if (counterOutOfOrder(T)) { - // Counter can get decremented out-of-order when there - // are multiple types event in the bracket. Also emit an s_wait counter - // with a conservative value of 0 for the counter. - addWait(Wait, T, 0); - } else { - // If a counter has been maxed out avoid overflow by waiting for - // MAX(CounterType) - 1 instead. - unsigned NeededWait = std::min(UB - ScoreToWait, getWaitCountMax(T) - 1); - addWait(Wait, T, NeededWait); + for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { + unsigned ScoreToWait = getRegScore(RegNo, T); + + // If the score of src_operand falls within the bracket, we need an + // s_waitcnt instruction. + if ((UB >= ScoreToWait) && (ScoreToWait > LB)) { + if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() && + !ST->hasFlatLgkmVMemCountInOrder()) { + // If there is a pending FLAT operation, and this is a VMem or LGKM + // waitcnt and the target can report early completion, then we need + // to force a waitcnt 0. + addWait(Wait, T, 0); + } else if (counterOutOfOrder(T)) { + // Counter can get decremented out-of-order when there + // are multiple types event in the bracket. Also emit an s_wait counter + // with a conservative value of 0 for the counter. + addWait(Wait, T, 0); + } else { + // If a counter has been maxed out avoid overflow by waiting for + // MAX(CounterType) - 1 instead. + unsigned NeededWait = + std::min(UB - ScoreToWait, getWaitCountMax(T) - 1); + addWait(Wait, T, NeededWait); + } } } } @@ -1670,18 +1696,16 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI, RegInterval CallAddrOpInterval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp); - for (int RegNo = CallAddrOpInterval.first; - RegNo < CallAddrOpInterval.second; ++RegNo) - ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait); + ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval, + Wait); if (const auto *RtnAddrOp = TII->getNamedOperand(MI, AMDGPU::OpName::dst)) { RegInterval RtnAddrOpInterval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp); - for (int RegNo = RtnAddrOpInterval.first; - RegNo < RtnAddrOpInterval.second; ++RegNo) - ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait); + ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval, + Wait); } } } else { @@ -1750,36 +1774,34 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI, RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op); const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg()); - for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) { - if (IsVGPR) { - // Implicit VGPR defs and uses are never a part of the memory - // instructions description and usually present to account for - // super-register liveness. - // TODO: Most of the other instructions also have implicit uses - // for the liveness accounting only. - if (Op.isImplicit() && MI.mayLoadOrStore()) - continue; - - // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the - // previous write and this write are the same type of VMEM - // instruction, in which case they are (in some architectures) - // guaranteed to write their results in order anyway. - if (Op.isUse() || !updateVMCntOnly(MI) || - ScoreBrackets.hasOtherPendingVmemTypes(RegNo, - getVmemType(MI)) || - !ST->hasVmemWriteVgprInOrder()) { - ScoreBrackets.determineWait(LOAD_CNT, RegNo, Wait); - ScoreBrackets.determineWait(SAMPLE_CNT, RegNo, Wait); - ScoreBrackets.determineWait(BVH_CNT, RegNo, Wait); - ScoreBrackets.clearVgprVmemTypes(RegNo); - } - if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) { - ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait); - } - ScoreBrackets.determineWait(DS_CNT, RegNo, Wait); - } else { - ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait); + if (IsVGPR) { + // Implicit VGPR defs and uses are never a part of the memory + // instructions description and usually present to account for + // super-register liveness. + // TODO: Most of the other instructions also have implicit uses + // for the liveness accounting only. + if (Op.isImplicit() && MI.mayLoadOrStore()) + continue; + + // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the + // previous write and this write are the same type of VMEM + // instruction, in which case they are (in some architectures) + // guaranteed to write their results in order anyway. + if (Op.isUse() || !updateVMCntOnly(MI) || + ScoreBrackets.hasOtherPendingVmemTypes(Interval, + getVmemType(MI)) || + !ST->hasVmemWriteVgprInOrder()) { + ScoreBrackets.determineWait(LOAD_CNT, Interval, Wait); + ScoreBrackets.determineWait(SAMPLE_CNT, Interval, Wait); + ScoreBrackets.determineWait(BVH_CNT, Interval, Wait); + ScoreBrackets.clearVgprVmemTypes(Interval); + } + if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) { + ScoreBrackets.determineWait(EXP_CNT, Interval, Wait); } + ScoreBrackets.determineWait(DS_CNT, Interval, Wait); + } else { + ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait); } } }