Skip to content

Reland "[CodeGen] Support start/stop in CodeGenPassBuilder (#70912)" #78570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 113 additions & 20 deletions llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/UnreachableBlockElim.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
Expand Down Expand Up @@ -176,73 +177,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Function object to maintain state while adding codegen IR passes.
class AddIRPass {
public:
AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
~AddIRPass() {
if (!FPM.isEmpty())
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}

template <typename PassT> void operator()(PassT &&Pass) {
template <typename PassT>
void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
static_assert((is_detected<is_function_pass_t, PassT>::value ||
is_detected<is_module_pass_t, PassT>::value) &&
"Only module pass and function pass are supported.");

if (!PB.runBeforeAdding(Name))
return;

// Add Function Pass
if constexpr (is_detected<is_function_pass_t, PassT>::value) {
FPM.addPass(std::forward<PassT>(Pass));

for (auto &C : PB.AfterCallbacks)
C(Name);
} else {
// Add Module Pass
if (!FPM.isEmpty()) {
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
FPM = FunctionPassManager();
}

MPM.addPass(std::forward<PassT>(Pass));

for (auto &C : PB.AfterCallbacks)
C(Name);
}
}

private:
ModulePassManager &MPM;
FunctionPassManager FPM;
const DerivedT &PB;
};

// Function object to maintain state while adding codegen machine passes.
class AddMachinePass {
public:
AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
: PM(PM), PB(PB) {}

template <typename PassT> void operator()(PassT &&Pass) {
static_assert(
is_detected<has_key_t, PassT>::value,
"Machine function pass must define a static member variable `Key`.");
for (auto &C : BeforeCallbacks)
if (!C(&PassT::Key))
return;

if (!PB.runBeforeAdding(PassT::name()))
return;

PM.addPass(std::forward<PassT>(Pass));
for (auto &C : AfterCallbacks)
C(&PassT::Key);

for (auto &C : PB.AfterCallbacks)
C(PassT::name());
}

template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
AfterCallbacks.emplace_back(
PB.AfterCallbacks.emplace_back(
[this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
if (PassID == ID)
this->PM.addPass(std::move(Pass));
});
}

void disablePass(MachinePassKey *ID) {
BeforeCallbacks.emplace_back(
[ID](MachinePassKey *PassID) { return PassID != ID; });
}

MachineFunctionPassManager releasePM() { return std::move(PM); }

private:
MachineFunctionPassManager &PM;
SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
BeforeCallbacks;
SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
AfterCallbacks;
const DerivedT &PB;
};

LLVMTargetMachine &TM;
Expand Down Expand Up @@ -473,20 +481,43 @@ template <typename DerivedT> class CodeGenPassBuilder {
const DerivedT &derived() const {
return static_cast<const DerivedT &>(*this);
}

bool runBeforeAdding(StringRef Name) const {
bool ShouldAdd = true;
for (auto &C : BeforeCallbacks)
ShouldAdd &= C(Name);
return ShouldAdd;
}

void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;

Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;

mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
BeforeCallbacks;
mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;

/// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
mutable bool Started = true;
mutable bool Stopped = true;
};

template <typename Derived>
Error CodeGenPassBuilder<Derived>::buildPipeline(
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
CodeGenFileType FileType) const {
AddIRPass addIRPass(MPM);
auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
if (!StartStopInfo)
return StartStopInfo.takeError();
setStartStopPasses(*StartStopInfo);
AddIRPass addIRPass(MPM, derived());
// `ProfileSummaryInfo` is always valid.
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
addISelPasses(addIRPass);

AddMachinePass addPass(MFPM);
AddMachinePass addPass(MFPM, derived());
if (auto Err = addCoreISelPasses(addPass))
return std::move(Err);

Expand All @@ -499,6 +530,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
});

addPass(FreeMachineFunctionPass());
return verifyStartStop(*StartStopInfo);
}

template <typename Derived>
void CodeGenPassBuilder<Derived>::setStartStopPasses(
const TargetPassConfig::StartStopInfo &Info) const {
if (!Info.StartPass.empty()) {
Started = false;
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter,
Count = 0u](StringRef ClassName) mutable {
if (Count == Info.StartInstanceNum) {
if (AfterFlag) {
AfterFlag = false;
Started = true;
}
return Started;
}

auto PassName = PIC->getPassNameForClassName(ClassName);
if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
Started = !Info.StartAfter;

return Started;
});
}

if (!Info.StopPass.empty()) {
Stopped = false;
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter,
Count = 0u](StringRef ClassName) mutable {
if (Count == Info.StopInstanceNum) {
if (AfterFlag) {
AfterFlag = false;
Stopped = true;
}
return !Stopped;
}

auto PassName = PIC->getPassNameForClassName(ClassName);
if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
Stopped = !Info.StopAfter;
return !Stopped;
});
}
}

template <typename Derived>
Error CodeGenPassBuilder<Derived>::verifyStartStop(
const TargetPassConfig::StartStopInfo &Info) const {
if (Started && Stopped)
return Error::success();

if (!Started)
return make_error<StringError>(
"Can't find start pass \"" +
PIC->getPassNameForClassName(Info.StartPass) + "\".",
std::make_error_code(std::errc::invalid_argument));
if (!Stopped)
return make_error<StringError>(
"Can't find stop pass \"" +
PIC->getPassNameForClassName(Info.StopPass) + "\".",
std::make_error_code(std::errc::invalid_argument));
return Error::success();
}

Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/CodeGen/TargetPassConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "llvm/Pass.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Error.h"
#include <cassert>
#include <string>

Expand Down Expand Up @@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
static std::string
getLimitedCodeGenPipelineReason(const char *Separator = "/");

struct StartStopInfo {
bool StartAfter;
bool StopAfter;
unsigned StartInstanceNum;
unsigned StopInstanceNum;
StringRef StartPass;
StringRef StopPass;
};

/// Returns pass name in `-stop-before` or `-stop-after`
/// NOTE: New pass manager migration only
static Expected<StartStopInfo>
getStartStopInfo(PassInstrumentationCallbacks &PIC);

void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }

bool getEnableTailMerge() const { return EnableTailMerge; }
Expand Down
109 changes: 33 additions & 76 deletions llvm/lib/CodeGen/TargetPassConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,81 +504,6 @@ CGPassBuilderOption llvm::getCGPassBuilderOption() {
return Opt;
}

static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
LLVMTargetMachine &LLVMTM) {
StringRef StartBefore;
StringRef StartAfter;
StringRef StopBefore;
StringRef StopAfter;

unsigned StartBeforeInstanceNum = 0;
unsigned StartAfterInstanceNum = 0;
unsigned StopBeforeInstanceNum = 0;
unsigned StopAfterInstanceNum = 0;

std::tie(StartBefore, StartBeforeInstanceNum) =
getPassNameAndInstanceNum(StartBeforeOpt);
std::tie(StartAfter, StartAfterInstanceNum) =
getPassNameAndInstanceNum(StartAfterOpt);
std::tie(StopBefore, StopBeforeInstanceNum) =
getPassNameAndInstanceNum(StopBeforeOpt);
std::tie(StopAfter, StopAfterInstanceNum) =
getPassNameAndInstanceNum(StopAfterOpt);

if (StartBefore.empty() && StartAfter.empty() && StopBefore.empty() &&
StopAfter.empty())
return;

std::tie(StartBefore, std::ignore) =
LLVMTM.getPassNameFromLegacyName(StartBefore);
std::tie(StartAfter, std::ignore) =
LLVMTM.getPassNameFromLegacyName(StartAfter);
std::tie(StopBefore, std::ignore) =
LLVMTM.getPassNameFromLegacyName(StopBefore);
std::tie(StopAfter, std::ignore) =
LLVMTM.getPassNameFromLegacyName(StopAfter);
if (!StartBefore.empty() && !StartAfter.empty())
report_fatal_error(Twine(StartBeforeOptName) + Twine(" and ") +
Twine(StartAfterOptName) + Twine(" specified!"));
if (!StopBefore.empty() && !StopAfter.empty())
report_fatal_error(Twine(StopBeforeOptName) + Twine(" and ") +
Twine(StopAfterOptName) + Twine(" specified!"));

PIC.registerShouldRunOptionalPassCallback(
[=, EnableCurrent = StartBefore.empty() && StartAfter.empty(),
EnableNext = std::optional<bool>(), StartBeforeCount = 0u,
StartAfterCount = 0u, StopBeforeCount = 0u,
StopAfterCount = 0u](StringRef P, Any) mutable {
bool StartBeforePass = !StartBefore.empty() && P.contains(StartBefore);
bool StartAfterPass = !StartAfter.empty() && P.contains(StartAfter);
bool StopBeforePass = !StopBefore.empty() && P.contains(StopBefore);
bool StopAfterPass = !StopAfter.empty() && P.contains(StopAfter);

// Implement -start-after/-stop-after
if (EnableNext) {
EnableCurrent = *EnableNext;
EnableNext.reset();
}

// Using PIC.registerAfterPassCallback won't work because if this
// callback returns false, AfterPassCallback is also skipped.
if (StartAfterPass && StartAfterCount++ == StartAfterInstanceNum) {
assert(!EnableNext && "Error: assign to EnableNext more than once");
EnableNext = true;
}
if (StopAfterPass && StopAfterCount++ == StopAfterInstanceNum) {
assert(!EnableNext && "Error: assign to EnableNext more than once");
EnableNext = false;
}

if (StartBeforePass && StartBeforeCount++ == StartBeforeInstanceNum)
EnableCurrent = true;
if (StopBeforePass && StopBeforeCount++ == StopBeforeInstanceNum)
EnableCurrent = false;
return EnableCurrent;
});
}

void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
LLVMTargetMachine &LLVMTM) {

Expand All @@ -605,8 +530,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,

return true;
});
}

registerPartialPipelineCallback(PIC, LLVMTM);
Expected<TargetPassConfig::StartStopInfo>
TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
auto [StartBefore, StartBeforeInstanceNum] =
getPassNameAndInstanceNum(StartBeforeOpt);
auto [StartAfter, StartAfterInstanceNum] =
getPassNameAndInstanceNum(StartAfterOpt);
auto [StopBefore, StopBeforeInstanceNum] =
getPassNameAndInstanceNum(StopBeforeOpt);
auto [StopAfter, StopAfterInstanceNum] =
getPassNameAndInstanceNum(StopAfterOpt);

if (!StartBefore.empty() && !StartAfter.empty())
return make_error<StringError>(
Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
std::make_error_code(std::errc::invalid_argument));
if (!StopBefore.empty() && !StopAfter.empty())
return make_error<StringError>(
Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
std::make_error_code(std::errc::invalid_argument));

StartStopInfo Result;
Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
Result.StartInstanceNum =
StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
Result.StopInstanceNum =
StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
Result.StartAfter = !StartAfter.empty();
Result.StopAfter = !StopAfter.empty();
Result.StartInstanceNum += Result.StartInstanceNum == 0;
Result.StopInstanceNum += Result.StopInstanceNum == 0;
return Result;
}

// Out of line constructor provides default values for pass options and
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TypePromotion.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
Expand Down Expand Up @@ -316,7 +317,8 @@ namespace {
/// We currently only use this for --print-before/after.
bool shouldPopulateClassToPassNames() {
return PrintPipelinePasses || !printBeforePasses().empty() ||
!printAfterPasses().empty() || !isFilterPassesEmpty();
!printAfterPasses().empty() || !isFilterPassesEmpty() ||
TargetPassConfig::hasLimitedCodeGenPipeline();
}

// A pass for testing -print-on-crash.
Expand Down