Skip to content

Reapply "[CodeGen][NewPM] Port machine-branch-prob to new pass manager" (#96858) #96869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions llvm/include/llvm/CodeGen/MachineBranchProbabilityInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
#define LLVM_CODEGEN_MACHINEBRANCHPROBABILITYINFO_H

#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/BranchProbability.h"

namespace llvm {

class MachineBranchProbabilityInfo : public ImmutablePass {
virtual void anchor();

class MachineBranchProbabilityInfo {
// Default weight value. Used when we don't have information about the edge.
// TODO: DEFAULT_WEIGHT makes sense during static predication, when none of
// the successors have a weight yet. But it doesn't make sense when providing
Expand All @@ -31,13 +30,8 @@ class MachineBranchProbabilityInfo : public ImmutablePass {
static const uint32_t DEFAULT_WEIGHT = 16;

public:
static char ID;

MachineBranchProbabilityInfo();

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
}
bool invalidate(MachineFunction &, const PreservedAnalyses &PA,
MachineFunctionAnalysisManager::Invalidator &);

// Return edge probability.
BranchProbability getEdgeProbability(const MachineBasicBlock *Src,
Expand All @@ -61,6 +55,45 @@ class MachineBranchProbabilityInfo : public ImmutablePass {
const MachineBasicBlock *Dst) const;
};

class MachineBranchProbabilityAnalysis
: public AnalysisInfoMixin<MachineBranchProbabilityAnalysis> {
friend AnalysisInfoMixin<MachineBranchProbabilityAnalysis>;

static AnalysisKey Key;

public:
using Result = MachineBranchProbabilityInfo;

Result run(MachineFunction &, MachineFunctionAnalysisManager &);
};

class MachineBranchProbabilityPrinterPass
: public PassInfoMixin<MachineBranchProbabilityPrinterPass> {
raw_ostream &OS;

public:
MachineBranchProbabilityPrinterPass(raw_ostream &OS) : OS(OS) {}
PreservedAnalyses run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM);
};

class MachineBranchProbabilityInfoWrapperPass : public ImmutablePass {
virtual void anchor();

MachineBranchProbabilityInfo MBPI;

public:
static char ID;

MachineBranchProbabilityInfoWrapperPass();

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
}

MachineBranchProbabilityInfo &getMBPI() { return MBPI; }
const MachineBranchProbabilityInfo &getMBPI() const { return MBPI; }
};
}


Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void initializeMIRPrintingPassPass(PassRegistry&);
void initializeMachineBlockFrequencyInfoPass(PassRegistry&);
void initializeMachineBlockPlacementPass(PassRegistry&);
void initializeMachineBlockPlacementStatsPass(PassRegistry&);
void initializeMachineBranchProbabilityInfoPass(PassRegistry&);
void initializeMachineBranchProbabilityInfoWrapperPassPass(PassRegistry &);
void initializeMachineCFGPrinterPass(PassRegistry &);
void initializeMachineCSEPass(PassRegistry&);
void initializeMachineCombinerPass(PassRegistry&);
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Passes/MachinePassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ LOOP_PASS("loop-reduce", LoopStrengthReducePass())
#ifndef MACHINE_FUNCTION_ANALYSIS
#define MACHINE_FUNCTION_ANALYSIS(NAME, CREATE_PASS)
#endif
MACHINE_FUNCTION_ANALYSIS("machine-branch-prob",
MachineBranchProbabilityAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-dom-tree", MachineDominatorTreeAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-post-dom-tree",
MachinePostDominatorTreeAnalysis())
Expand Down Expand Up @@ -130,6 +132,8 @@ MACHINE_FUNCTION_PASS("finalize-isel", FinalizeISelPass())
MACHINE_FUNCTION_PASS("localstackalloc", LocalStackSlotAllocationPass())
MACHINE_FUNCTION_PASS("no-op-machine-function", NoOpMachineFunctionPass())
MACHINE_FUNCTION_PASS("print", PrintMIRPass())
MACHINE_FUNCTION_PASS("print<machine-branch-prob>",
MachineBranchProbabilityPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-dom-tree>",
MachineDominatorTreePrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-post-dom-tree>",
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineOptimizationRemarkEmitterPass>();
AU.addRequired<GCModuleInfo>();
AU.addRequired<LazyMachineBlockFrequencyInfoPass>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
}

bool AsmPrinter::doInitialization(Module &M) {
Expand Down Expand Up @@ -1478,8 +1478,9 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
? &getAnalysis<LazyMachineBlockFrequencyInfoPass>().getBFI()
: nullptr;
const MachineBranchProbabilityInfo *MBPI =
Features.BrProb ? &getAnalysis<MachineBranchProbabilityInfo>()
: nullptr;
Features.BrProb
? &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI()
: nullptr;

if (Features.BBFreq || Features.BrProb) {
for (const MachineBasicBlock &MBB : MF) {
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/CodeGen/BranchFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ namespace {

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
AU.addRequired<TargetPassConfig>();
MachineFunctionPass::getAnalysisUsage(AU);
Expand Down Expand Up @@ -131,9 +131,10 @@ bool BranchFolderPass::runOnMachineFunction(MachineFunction &MF) {
PassConfig->getEnableTailMerge();
MBFIWrapper MBBFreqInfo(
getAnalysis<MachineBlockFrequencyInfo>());
BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true, MBBFreqInfo,
getAnalysis<MachineBranchProbabilityInfo>(),
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI());
BranchFolder Folder(
EnableTailMerge, /*CommonHoist=*/true, MBBFreqInfo,
getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI(),
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI());
return Folder.OptimizeFunction(MF, MF.getSubtarget().getInstrInfo(),
MF.getSubtarget().getRegisterInfo());
}
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/CodeGen/EarlyIfConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,14 +787,14 @@ char &llvm::EarlyIfConverterID = EarlyIfConverter::ID;

INITIALIZE_PASS_BEGIN(EarlyIfConverter, DEBUG_TYPE,
"Early If Converter", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineTraceMetrics)
INITIALIZE_PASS_END(EarlyIfConverter, DEBUG_TYPE,
"Early If Converter", false, false)

void EarlyIfConverter::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addPreserved<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfo>();
Expand Down Expand Up @@ -1142,12 +1142,12 @@ char &llvm::EarlyIfPredicatorID = EarlyIfPredicator::ID;
INITIALIZE_PASS_BEGIN(EarlyIfPredicator, DEBUG_TYPE, "Early If Predicator",
false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_END(EarlyIfPredicator, DEBUG_TYPE, "Early If Predicator", false,
false)

void EarlyIfPredicator::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addPreserved<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfo>();
Expand Down Expand Up @@ -1222,7 +1222,7 @@ bool EarlyIfPredicator::runOnMachineFunction(MachineFunction &MF) {
SchedModel.init(&STI);
DomTree = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
Loops = &getAnalysis<MachineLoopInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();

bool Changed = false;
IfConv.runOnMachineFunction(MF);
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ INITIALIZE_PASS_BEGIN(RegBankSelect, DEBUG_TYPE,
"Assign register bank of generic virtual registers",
false, false);
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE,
"Assign register bank of generic virtual registers", false,
Expand All @@ -86,7 +86,7 @@ void RegBankSelect::init(MachineFunction &MF) {
TPC = &getAnalysis<TargetPassConfig>();
if (OptMode != Mode::Fast) {
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
} else {
MBFI = nullptr;
MBPI = nullptr;
Expand All @@ -100,7 +100,7 @@ void RegBankSelect::getAnalysisUsage(AnalysisUsage &AU) const {
// We could preserve the information from these two analysis but
// the APIs do not allow to do so yet.
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
}
AU.addRequired<TargetPassConfig>();
getSelectionDAGFallbackAnalysisUsage(AU);
Expand Down Expand Up @@ -955,8 +955,10 @@ uint64_t RegBankSelect::EdgeInsertPoint::frequency(const Pass &P) const {
if (WasMaterialized)
return MBFI->getBlockFreq(DstOrSplit).getFrequency();

auto *MBPIWrapper =
P.getAnalysisIfAvailable<MachineBranchProbabilityInfoWrapperPass>();
const MachineBranchProbabilityInfo *MBPI =
P.getAnalysisIfAvailable<MachineBranchProbabilityInfo>();
MBPIWrapper ? &MBPIWrapper->getMBPI() : nullptr;
if (!MBPI)
return 1;
// The basic block will be on the edge.
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/IfConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ namespace {

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand Down Expand Up @@ -432,7 +432,7 @@ char IfConverter::ID = 0;
char &llvm::IfConverterID = IfConverter::ID;

INITIALIZE_PASS_BEGIN(IfConverter, DEBUG_TYPE, "If Converter", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_END(IfConverter, DEBUG_TYPE, "If Converter", false, false)

Expand All @@ -445,7 +445,7 @@ bool IfConverter::runOnMachineFunction(MachineFunction &MF) {
TII = ST.getInstrInfo();
TRI = ST.getRegisterInfo();
MBFIWrapper MBFI(getAnalysis<MachineBlockFrequencyInfo>());
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
ProfileSummaryInfo *PSI =
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
MRI = &MF.getRegInfo();
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/LazyMachineBlockFrequencyInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using namespace llvm;

INITIALIZE_PASS_BEGIN(LazyMachineBlockFrequencyInfoPass, DEBUG_TYPE,
"Lazy Machine Block Frequency Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(LazyMachineBlockFrequencyInfoPass, DEBUG_TYPE,
"Lazy Machine Block Frequency Analysis", true, true)
Expand All @@ -43,7 +43,7 @@ void LazyMachineBlockFrequencyInfoPass::print(raw_ostream &OS,

void LazyMachineBlockFrequencyInfoPass::getAnalysisUsage(
AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand All @@ -62,7 +62,7 @@ LazyMachineBlockFrequencyInfoPass::calculateIfNotAvailable() const {
return *MBFI;
}

auto &MBPI = getAnalysis<MachineBranchProbabilityInfo>();
auto &MBPI = getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
auto *MLI = getAnalysisIfAvailable<MachineLoopInfo>();
auto *MDTWrapper = getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
auto *MDT = MDTWrapper ? &MDTWrapper->getDomTree() : nullptr;
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct DOTGraphTraits<MachineBlockFrequencyInfo *>

INITIALIZE_PASS_BEGIN(MachineBlockFrequencyInfo, DEBUG_TYPE,
"Machine Block Frequency Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(MachineBlockFrequencyInfo, DEBUG_TYPE,
"Machine Block Frequency Analysis", true, true)
Expand All @@ -185,7 +185,7 @@ MachineBlockFrequencyInfo::MachineBlockFrequencyInfo(
MachineBlockFrequencyInfo::~MachineBlockFrequencyInfo() = default;

void MachineBlockFrequencyInfo::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineLoopInfo>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
Expand All @@ -209,7 +209,7 @@ void MachineBlockFrequencyInfo::calculate(

bool MachineBlockFrequencyInfo::runOnMachineFunction(MachineFunction &F) {
MachineBranchProbabilityInfo &MBPI =
getAnalysis<MachineBranchProbabilityInfo>();
getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
MachineLoopInfo &MLI = getAnalysis<MachineLoopInfo>();
calculate(F, MBPI, MLI);
return false;
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/CodeGen/MachineBlockPlacement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ class MachineBlockPlacement : public MachineFunctionPass {
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineBlockFrequencyInfo>();
if (TailDupPlacement)
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
Expand All @@ -627,7 +627,7 @@ char &llvm::MachineBlockPlacementID = MachineBlockPlacement::ID;

INITIALIZE_PASS_BEGIN(MachineBlockPlacement, DEBUG_TYPE,
"Branch Probability Basic Block Placement", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
Expand Down Expand Up @@ -3425,7 +3425,7 @@ bool MachineBlockPlacement::runOnMachineFunction(MachineFunction &MF) {
return false;

F = &MF;
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
MBFI = std::make_unique<MBFIWrapper>(
getAnalysis<MachineBlockFrequencyInfo>());
MLI = &getAnalysis<MachineLoopInfo>();
Expand Down Expand Up @@ -3726,7 +3726,7 @@ class MachineBlockPlacementStats : public MachineFunctionPass {
bool runOnMachineFunction(MachineFunction &F) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineBlockFrequencyInfo>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
Expand All @@ -3741,7 +3741,7 @@ char &llvm::MachineBlockPlacementStatsID = MachineBlockPlacementStats::ID;

INITIALIZE_PASS_BEGIN(MachineBlockPlacementStats, "block-placement-stats",
"Basic Block Placement Stats", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_END(MachineBlockPlacementStats, "block-placement-stats",
"Basic Block Placement Stats", false, false)
Expand All @@ -3754,7 +3754,7 @@ bool MachineBlockPlacementStats::runOnMachineFunction(MachineFunction &F) {
if (!isFunctionInPrintList(F.getName()))
return false;

MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();

for (MachineBasicBlock &MBB : F) {
Expand Down
Loading
Loading