Skip to content

[CodeGen][NewPM] Port machine-block-freq to new pass manager #98317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion llvm/include/llvm/CodeGen/LazyMachineBlockFrequencyInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class LazyMachineBlockFrequencyInfoPass : public MachineFunctionPass {

bool runOnMachineFunction(MachineFunction &F) override;
void releaseMemory() override;
void print(raw_ostream &OS, const Module *M) const override;
};
}
#endif
63 changes: 54 additions & 9 deletions llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define LLVM_CODEGEN_MACHINEBLOCKFREQUENCYINFO_H

#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/Support/BlockFrequency.h"
#include <cstdint>
#include <memory>
Expand All @@ -30,29 +31,30 @@ class raw_ostream;

/// MachineBlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation
/// to estimate machine basic block frequencies.
class MachineBlockFrequencyInfo : public MachineFunctionPass {
class MachineBlockFrequencyInfo {
using ImplType = BlockFrequencyInfoImpl<MachineBasicBlock>;
std::unique_ptr<ImplType> MBFI;

public:
static char ID;

MachineBlockFrequencyInfo();
MachineBlockFrequencyInfo(); // Legacy pass manager only.
explicit MachineBlockFrequencyInfo(MachineFunction &F,
MachineBranchProbabilityInfo &MBPI,
MachineLoopInfo &MLI);
~MachineBlockFrequencyInfo() override;

void getAnalysisUsage(AnalysisUsage &AU) const override;
MachineBlockFrequencyInfo(MachineBlockFrequencyInfo &&);
~MachineBlockFrequencyInfo();

bool runOnMachineFunction(MachineFunction &F) override;
/// Handle invalidation explicitly.
bool invalidate(MachineFunction &F, const PreservedAnalyses &PA,
MachineFunctionAnalysisManager::Invalidator &);

/// calculate - compute block frequency info for the given function.
void calculate(const MachineFunction &F,
const MachineBranchProbabilityInfo &MBPI,
const MachineLoopInfo &MLI);

void releaseMemory() override;
void print(raw_ostream &OS);

void releaseMemory();

/// getblockFreq - Return block frequency. Return 0 if we don't have the
/// information. Please note that initial frequency is equal to 1024. It means
Expand Down Expand Up @@ -107,6 +109,49 @@ Printable printBlockFreq(const MachineBlockFrequencyInfo &MBFI,
Printable printBlockFreq(const MachineBlockFrequencyInfo &MBFI,
const MachineBasicBlock &MBB);

class MachineBlockFrequencyAnalysis
: public AnalysisInfoMixin<MachineBlockFrequencyAnalysis> {
friend AnalysisInfoMixin<MachineBlockFrequencyAnalysis>;
static AnalysisKey Key;

public:
using Result = MachineBlockFrequencyInfo;

Result run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM);
};

/// Printer pass for the \c MachineBlockFrequencyInfo results.
class MachineBlockFrequencyPrinterPass
: public PassInfoMixin<MachineBlockFrequencyPrinterPass> {
raw_ostream &OS;

public:
explicit MachineBlockFrequencyPrinterPass(raw_ostream &OS) : OS(OS) {}

PreservedAnalyses run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM);

static bool isRequired() { return true; }
};

class MachineBlockFrequencyInfoWrapperPass : public MachineFunctionPass {
MachineBlockFrequencyInfo MBFI;

public:
static char ID;

MachineBlockFrequencyInfoWrapperPass();

void getAnalysisUsage(AnalysisUsage &AU) const override;

bool runOnMachineFunction(MachineFunction &F) override;

void releaseMemory() override { MBFI.releaseMemory(); }

MachineBlockFrequencyInfo &getMBFI() { return MBFI; }

const MachineBlockFrequencyInfo &getMBFI() const { return MBFI; }
};
} // end namespace llvm

#endif // LLVM_CODEGEN_MACHINEBLOCKFREQUENCYINFO_H
2 changes: 1 addition & 1 deletion llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void initializeMIRAddFSDiscriminatorsPass(PassRegistry &);
void initializeMIRCanonicalizerPass(PassRegistry &);
void initializeMIRNamerPass(PassRegistry &);
void initializeMIRPrintingPassPass(PassRegistry&);
void initializeMachineBlockFrequencyInfoPass(PassRegistry&);
void initializeMachineBlockFrequencyInfoWrapperPassPass(PassRegistry &);
void initializeMachineBlockPlacementPass(PassRegistry&);
void initializeMachineBlockPlacementStatsPass(PassRegistry&);
void initializeMachineBranchProbabilityInfoWrapperPassPass(PassRegistry &);
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/Passes/MachinePassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ LOOP_PASS("loop-reduce", LoopStrengthReducePass())
// preferably fix the scavenger to not depend on them).
MACHINE_FUNCTION_ANALYSIS("live-intervals", LiveIntervalsAnalysis())
MACHINE_FUNCTION_ANALYSIS("live-vars", LiveVariablesAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-block-freq", MachineBlockFrequencyAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-branch-prob",
MachineBranchProbabilityAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-dom-tree", MachineDominatorTreeAnalysis())
Expand All @@ -108,7 +109,6 @@ MACHINE_FUNCTION_ANALYSIS("slot-indexes", SlotIndexesAnalysis())
// MACHINE_FUNCTION_ANALYSIS("edge-bundles", EdgeBundlesAnalysis())
// MACHINE_FUNCTION_ANALYSIS("lazy-machine-bfi",
// LazyMachineBlockFrequencyInfoAnalysis())
// MACHINE_FUNCTION_ANALYSIS("machine-bfi", MachineBlockFrequencyInfoAnalysis())
// MACHINE_FUNCTION_ANALYSIS("machine-loops", MachineLoopInfoAnalysis())
// MACHINE_FUNCTION_ANALYSIS("machine-dom-frontier",
// MachineDominanceFrontierAnalysis())
Expand All @@ -135,6 +135,8 @@ MACHINE_FUNCTION_PASS("no-op-machine-function", NoOpMachineFunctionPass())
MACHINE_FUNCTION_PASS("print", PrintMIRPass())
MACHINE_FUNCTION_PASS("print<live-intervals>", LiveIntervalsPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<live-vars>", LiveVariablesPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-block-freq>",
MachineBlockFrequencyPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-branch-prob>",
MachineBranchProbabilityPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-dom-tree>",
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/BranchFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ namespace {
bool runOnMachineFunction(MachineFunction &MF) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
AU.addRequired<TargetPassConfig>();
Expand Down Expand Up @@ -130,7 +130,7 @@ bool BranchFolderPass::runOnMachineFunction(MachineFunction &MF) {
bool EnableTailMerge = !MF.getTarget().requiresStructuredCFG() &&
PassConfig->getEnableTailMerge();
MBFIWrapper MBBFreqInfo(
getAnalysis<MachineBlockFrequencyInfo>());
getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI());
BranchFolder Folder(
EnableTailMerge, /*CommonHoist=*/true, MBBFreqInfo,
getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI(),
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
initializeMIRCanonicalizerPass(Registry);
initializeMIRNamerPass(Registry);
initializeMIRProfileLoaderPassPass(Registry);
initializeMachineBlockFrequencyInfoPass(Registry);
initializeMachineBlockFrequencyInfoWrapperPassPass(Registry);
initializeMachineBlockPlacementPass(Registry);
initializeMachineBlockPlacementStatsPass(Registry);
initializeMachineCFGPrinterPass(Registry);
Expand Down
29 changes: 15 additions & 14 deletions llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ char RegBankSelect::ID = 0;
INITIALIZE_PASS_BEGIN(RegBankSelect, DEBUG_TYPE,
"Assign register bank of generic virtual registers",
false, false);
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE,
Expand All @@ -85,7 +85,7 @@ void RegBankSelect::init(MachineFunction &MF) {
TRI = MF.getSubtarget().getRegisterInfo();
TPC = &getAnalysis<TargetPassConfig>();
if (OptMode != Mode::Fast) {
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
} else {
MBFI = nullptr;
Expand All @@ -99,7 +99,7 @@ void RegBankSelect::getAnalysisUsage(AnalysisUsage &AU) const {
if (OptMode != Mode::Fast) {
// We could preserve the information from these two analysis but
// the APIs do not allow to do so yet.
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
}
AU.addRequired<TargetPassConfig>();
Expand Down Expand Up @@ -919,19 +919,19 @@ bool RegBankSelect::InstrInsertPoint::isSplit() const {
uint64_t RegBankSelect::InstrInsertPoint::frequency(const Pass &P) const {
// Even if we need to split, because we insert between terminators,
// this split has actually the same frequency as the instruction.
const MachineBlockFrequencyInfo *MBFI =
P.getAnalysisIfAvailable<MachineBlockFrequencyInfo>();
if (!MBFI)
const auto *MBFIWrapper =
P.getAnalysisIfAvailable<MachineBlockFrequencyInfoWrapperPass>();
if (!MBFIWrapper)
return 1;
return MBFI->getBlockFreq(Instr.getParent()).getFrequency();
return MBFIWrapper->getMBFI().getBlockFreq(Instr.getParent()).getFrequency();
}

uint64_t RegBankSelect::MBBInsertPoint::frequency(const Pass &P) const {
const MachineBlockFrequencyInfo *MBFI =
P.getAnalysisIfAvailable<MachineBlockFrequencyInfo>();
if (!MBFI)
const auto *MBFIWrapper =
P.getAnalysisIfAvailable<MachineBlockFrequencyInfoWrapperPass>();
if (!MBFIWrapper)
return 1;
return MBFI->getBlockFreq(&MBB).getFrequency();
return MBFIWrapper->getMBFI().getBlockFreq(&MBB).getFrequency();
}

void RegBankSelect::EdgeInsertPoint::materialize() {
Expand All @@ -948,10 +948,11 @@ void RegBankSelect::EdgeInsertPoint::materialize() {
}

uint64_t RegBankSelect::EdgeInsertPoint::frequency(const Pass &P) const {
const MachineBlockFrequencyInfo *MBFI =
P.getAnalysisIfAvailable<MachineBlockFrequencyInfo>();
if (!MBFI)
const auto *MBFIWrapper =
P.getAnalysisIfAvailable<MachineBlockFrequencyInfoWrapperPass>();
if (!MBFIWrapper)
return 1;
const auto *MBFI = &MBFIWrapper->getMBFI();
if (WasMaterialized)
return MBFI->getBlockFreq(DstOrSplit).getFrequency();

Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/CodeGen/IfConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ namespace {
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
Expand Down Expand Up @@ -444,7 +444,8 @@ bool IfConverter::runOnMachineFunction(MachineFunction &MF) {
TLI = ST.getTargetLowering();
TII = ST.getInstrInfo();
TRI = ST.getRegisterInfo();
MBFIWrapper MBFI(getAnalysis<MachineBlockFrequencyInfo>());
MBFIWrapper MBFI(
getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI());
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
ProfileSummaryInfo *PSI =
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/InlineSpiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ class HoistSpillHelper : private LiveRangeEdit::Delegate {
MDT(pass.getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree()),
VRM(vrm), MRI(mf.getRegInfo()), TII(*mf.getSubtarget().getInstrInfo()),
TRI(*mf.getSubtarget().getRegisterInfo()),
MBFI(pass.getAnalysis<MachineBlockFrequencyInfo>()),
MBFI(
pass.getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI()),
IPA(LIS, mf.getNumBlockIDs()) {}

void addToMergeableSpills(MachineInstr &Spill, int StackSlot,
Expand Down Expand Up @@ -193,7 +194,8 @@ class InlineSpiller : public Spiller {
MDT(Pass.getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree()),
VRM(VRM), MRI(MF.getRegInfo()), TII(*MF.getSubtarget().getInstrInfo()),
TRI(*MF.getSubtarget().getRegisterInfo()),
MBFI(Pass.getAnalysis<MachineBlockFrequencyInfo>()),
MBFI(
Pass.getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI()),
HSpiller(Pass, MF, VRM), VRAI(VRAI) {}

void spill(LiveRangeEdit &) override;
Expand Down
12 changes: 4 additions & 8 deletions llvm/lib/CodeGen/LazyMachineBlockFrequencyInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ LazyMachineBlockFrequencyInfoPass::LazyMachineBlockFrequencyInfoPass()
*PassRegistry::getPassRegistry());
}

void LazyMachineBlockFrequencyInfoPass::print(raw_ostream &OS,
const Module *M) const {
getBFI().print(OS, M);
}

void LazyMachineBlockFrequencyInfoPass::getAnalysisUsage(
AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
Expand All @@ -56,10 +51,11 @@ void LazyMachineBlockFrequencyInfoPass::releaseMemory() {

MachineBlockFrequencyInfo &
LazyMachineBlockFrequencyInfoPass::calculateIfNotAvailable() const {
auto *MBFI = getAnalysisIfAvailable<MachineBlockFrequencyInfo>();
if (MBFI) {
auto *MBFIWrapper =
getAnalysisIfAvailable<MachineBlockFrequencyInfoWrapperPass>();
if (MBFIWrapper) {
LLVM_DEBUG(dbgs() << "MachineBlockFrequencyInfo is available\n");
return *MBFI;
return MBFIWrapper->getMBFI();
}

auto &MBPI = getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/MIRSampleProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ char MIRProfileLoaderPass::ID = 0;
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
"Load MIR Sample Profile",
/* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
Expand Down Expand Up @@ -363,7 +363,7 @@ bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {

LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
<< MF.getFunction().getName() << "\n");
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
MIRSampleLoader->setInitVals(
&getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(),
&getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree(),
Expand Down Expand Up @@ -401,7 +401,7 @@ bool MIRProfileLoaderPass::doInitialization(Module &M) {

void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
AU.addRequiredTransitive<MachineLoopInfoWrapperPass>();
Expand Down
15 changes: 9 additions & 6 deletions llvm/lib/CodeGen/MLRegAllocEvictAdvisor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class RegAllocScoring : public MachineFunctionPass {
AU.setPreservesAll();
AU.addRequired<RegAllocEvictionAdvisorAnalysis>();
AU.addRequired<RegAllocPriorityAdvisorAnalysis>();
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}

Expand Down Expand Up @@ -388,7 +388,7 @@ class ReleaseModeEvictionAdvisorAnalysis final
std::vector<TensorSpec> InputFeatures;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
AU.addRequired<MachineLoopInfoWrapperPass>();
RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
}
Expand All @@ -406,7 +406,8 @@ class ReleaseModeEvictionAdvisorAnalysis final
InteractiveChannelBaseName + ".in");
}
return std::make_unique<MLEvictAdvisor>(
MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
MF, RA, Runner.get(),
getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI(),
getAnalysis<MachineLoopInfoWrapperPass>().getLI());
}
std::unique_ptr<MLModelRunner> Runner;
Expand Down Expand Up @@ -495,7 +496,7 @@ class DevelopmentModeEvictionAdvisorAnalysis final
std::vector<TensorSpec> TrainingInputFeatures;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
AU.addRequired<MachineLoopInfoWrapperPass>();
RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
}
Expand Down Expand Up @@ -544,7 +545,8 @@ class DevelopmentModeEvictionAdvisorAnalysis final
if (Log)
Log->switchContext(MF.getName());
return std::make_unique<DevelopmentModeEvictAdvisor>(
MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
MF, RA, Runner.get(),
getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI(),
getAnalysis<MachineLoopInfoWrapperPass>().getLI(), Log.get());
}

Expand Down Expand Up @@ -1139,7 +1141,8 @@ bool RegAllocScoring::runOnMachineFunction(MachineFunction &MF) {
auto GetReward = [&]() {
if (!CachedReward)
CachedReward = static_cast<float>(
calculateRegAllocScore(MF, getAnalysis<MachineBlockFrequencyInfo>())
calculateRegAllocScore(
MF, getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI())
.getScore());
return *CachedReward;
};
Expand Down
Loading
Loading