From ce88f2bcbf328bbfb49b7fd762ac43584b7a2e89 Mon Sep 17 00:00:00 2001 From: PaperChalice Date: Wed, 3 Jan 2024 10:34:33 +0800 Subject: [PATCH] [CodeGen][NewPM] Support start/stop in CodeGen --- .../include/llvm/CodeGen/CodeGenPassBuilder.h | 133 +++++++++++++++--- llvm/include/llvm/CodeGen/TargetPassConfig.h | 15 ++ llvm/lib/CodeGen/TargetPassConfig.cpp | 34 +++++ llvm/lib/Passes/PassBuilder.cpp | 4 +- .../CodeGen/CodeGenPassBuilderTest.cpp | 41 ++++++ 5 files changed, 206 insertions(+), 21 deletions(-) diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h index 0ea81347638e9..f425177ba2a90 100644 --- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h +++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h @@ -43,6 +43,7 @@ #include "llvm/CodeGen/ShadowStackGCLowering.h" #include "llvm/CodeGen/SjLjEHPrepare.h" #include "llvm/CodeGen/StackProtector.h" +#include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/UnreachableBlockElim.h" #include "llvm/CodeGen/WasmEHPrepare.h" #include "llvm/CodeGen/WinEHPrepare.h" @@ -175,73 +176,80 @@ template class CodeGenPassBuilder { // Function object to maintain state while adding codegen IR passes. class AddIRPass { public: - AddIRPass(ModulePassManager &MPM) : MPM(MPM) {} + AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {} ~AddIRPass() { if (!FPM.isEmpty()) MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); } - template void operator()(PassT &&Pass) { + template + void operator()(PassT &&Pass, StringRef Name = PassT::name()) { static_assert((is_detected::value || is_detected::value) && "Only module pass and function pass are supported."); + if (!PB.runBeforeAdding(Name)) + return; + // Add Function Pass if constexpr (is_detected::value) { FPM.addPass(std::forward(Pass)); + + for (auto &C : PB.AfterCallbacks) + C(Name); } else { // Add Module Pass if (!FPM.isEmpty()) { MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); FPM = FunctionPassManager(); } + MPM.addPass(std::forward(Pass)); + + for (auto &C : PB.AfterCallbacks) + C(Name); } } private: ModulePassManager &MPM; FunctionPassManager FPM; + const DerivedT &PB; }; // Function object to maintain state while adding codegen machine passes. class AddMachinePass { public: - AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {} + AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB) + : PM(PM), PB(PB) {} template void operator()(PassT &&Pass) { static_assert( is_detected::value, "Machine function pass must define a static member variable `Key`."); - for (auto &C : BeforeCallbacks) - if (!C(&PassT::Key)) - return; + + if (!PB.runBeforeAdding(PassT::name())) + return; + PM.addPass(std::forward(Pass)); - for (auto &C : AfterCallbacks) - C(&PassT::Key); + + for (auto &C : PB.AfterCallbacks) + C(PassT::name()); } template void insertPass(MachinePassKey *ID, PassT Pass) { - AfterCallbacks.emplace_back( + PB.AfterCallbacks.emplace_back( [this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) { if (PassID == ID) this->PM.addPass(std::move(Pass)); }); } - void disablePass(MachinePassKey *ID) { - BeforeCallbacks.emplace_back( - [ID](MachinePassKey *PassID) { return PassID != ID; }); - } - MachineFunctionPassManager releasePM() { return std::move(PM); } private: MachineFunctionPassManager ± - SmallVector, 4> - BeforeCallbacks; - SmallVector, 4> - AfterCallbacks; + const DerivedT &PB; }; LLVMTargetMachine &TM; @@ -469,6 +477,25 @@ template class CodeGenPassBuilder { const DerivedT &derived() const { return static_cast(*this); } + + bool runBeforeAdding(StringRef Name) const { + bool ShouldAdd = true; + for (auto &C : BeforeCallbacks) + ShouldAdd &= C(Name); + return ShouldAdd; + } + + void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const; + + Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const; + + mutable SmallVector, 4> + BeforeCallbacks; + mutable SmallVector, 4> AfterCallbacks; + + /// Helper variable for `-start-before/-start-after/-stop-before/-stop-after` + mutable bool Started = true; + mutable bool Stopped = true; }; template @@ -476,13 +503,17 @@ Error CodeGenPassBuilder::buildPipeline( ModulePassManager &MPM, MachineFunctionPassManager &MFPM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut, CodeGenFileType FileType) const { - AddIRPass addIRPass(MPM); + auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC); + if (!StartStopInfo) + return StartStopInfo.takeError(); + setStartStopPasses(*StartStopInfo); + AddIRPass addIRPass(MPM, derived()); // `ProfileSummaryInfo` is always valid. addIRPass(RequireAnalysisPass()); addIRPass(RequireAnalysisPass()); addISelPasses(addIRPass); - AddMachinePass addPass(MFPM); + AddMachinePass addPass(MFPM, derived()); if (auto Err = addCoreISelPasses(addPass)) return std::move(Err); @@ -495,6 +526,68 @@ Error CodeGenPassBuilder::buildPipeline( }); addPass(FreeMachineFunctionPass()); + return verifyStartStop(*StartStopInfo); +} + +template +void CodeGenPassBuilder::setStartStopPasses( + const TargetPassConfig::StartStopInfo &Info) const { + if (!Info.StartPass.empty()) { + Started = false; + BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter, + Count = 0u](StringRef ClassName) mutable { + if (Count == Info.StartInstanceNum) { + if (AfterFlag) { + AfterFlag = false; + Started = true; + } + return Started; + } + + auto PassName = PIC->getPassNameForClassName(ClassName); + if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum) + Started = !Info.StartAfter; + + return Started; + }); + } + + if (!Info.StopPass.empty()) { + Stopped = false; + BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter, + Count = 0u](StringRef ClassName) mutable { + if (Count == Info.StopInstanceNum) { + if (AfterFlag) { + AfterFlag = false; + Stopped = true; + } + return !Stopped; + } + + auto PassName = PIC->getPassNameForClassName(ClassName); + if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum) + Stopped = !Info.StopAfter; + return !Stopped; + }); + } +} + +template +Error CodeGenPassBuilder::verifyStartStop( + const TargetPassConfig::StartStopInfo &Info) const { + if (Started && Stopped) + return Error::success(); + + if (!Started) + return make_error( + "Can't find start pass \"" + + PIC->getPassNameForClassName(Info.StartPass) + "\".", + std::make_error_code(std::errc::invalid_argument)); + if (!Stopped) + return make_error( + "Can't find stop pass \"" + + PIC->getPassNameForClassName(Info.StopPass) + "\".", + std::make_error_code(std::errc::invalid_argument)); return Error::success(); } diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h index 66365419aa330..de6a760c4e4fd 100644 --- a/llvm/include/llvm/CodeGen/TargetPassConfig.h +++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h @@ -15,6 +15,7 @@ #include "llvm/Pass.h" #include "llvm/Support/CodeGen.h" +#include "llvm/Support/Error.h" #include #include @@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass { static std::string getLimitedCodeGenPipelineReason(const char *Separator = "/"); + struct StartStopInfo { + bool StartAfter; + bool StopAfter; + unsigned StartInstanceNum; + unsigned StopInstanceNum; + StringRef StartPass; + StringRef StopPass; + }; + + /// Returns pass name in `-stop-before` or `-stop-after` + /// NOTE: New pass manager migration only + static Expected + getStartStopInfo(PassInstrumentationCallbacks &PIC); + void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); } bool getEnableTailMerge() const { return EnableTailMerge; } diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index 3bbc792f4cbf4..52cf6b84f3272 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -609,6 +609,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC, registerPartialPipelineCallback(PIC, LLVMTM); } +Expected +TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) { + auto [StartBefore, StartBeforeInstanceNum] = + getPassNameAndInstanceNum(StartBeforeOpt); + auto [StartAfter, StartAfterInstanceNum] = + getPassNameAndInstanceNum(StartAfterOpt); + auto [StopBefore, StopBeforeInstanceNum] = + getPassNameAndInstanceNum(StopBeforeOpt); + auto [StopAfter, StopAfterInstanceNum] = + getPassNameAndInstanceNum(StopAfterOpt); + + if (!StartBefore.empty() && !StartAfter.empty()) + return make_error( + Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!", + std::make_error_code(std::errc::invalid_argument)); + if (!StopBefore.empty() && !StopAfter.empty()) + return make_error( + Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!", + std::make_error_code(std::errc::invalid_argument)); + + StartStopInfo Result; + Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore; + Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore; + Result.StartInstanceNum = + StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum; + Result.StopInstanceNum = + StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum; + Result.StartAfter = !StartAfter.empty(); + Result.StopAfter = !StopAfter.empty(); + Result.StartInstanceNum += Result.StartInstanceNum == 0; + Result.StopInstanceNum += Result.StopInstanceNum == 0; + return Result; +} + // Out of line constructor provides default values for pass options and // registers all common codegen passes. TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm) diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index d0f3a55a12b05..8843d9bd984ee 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -92,6 +92,7 @@ #include "llvm/CodeGen/ShadowStackGCLowering.h" #include "llvm/CodeGen/SjLjEHPrepare.h" #include "llvm/CodeGen/StackProtector.h" +#include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TypePromotion.h" #include "llvm/CodeGen/WasmEHPrepare.h" #include "llvm/CodeGen/WinEHPrepare.h" @@ -315,7 +316,8 @@ namespace { /// We currently only use this for --print-before/after. bool shouldPopulateClassToPassNames() { return PrintPipelinePasses || !printBeforePasses().empty() || - !printAfterPasses().empty() || !isFilterPassesEmpty(); + !printAfterPasses().empty() || !isFilterPassesEmpty() || + TargetPassConfig::hasLimitedCodeGenPipeline(); } // A pass for testing -print-on-crash. diff --git a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp index d6ec393155cf0..63499b056d1ef 100644 --- a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp +++ b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp @@ -138,4 +138,45 @@ TEST_F(CodeGenPassBuilderTest, basic) { EXPECT_EQ(MIRPipeline, ExpectedMIRPipeline); } +// TODO: Move this to lit test when llc support new pm. +TEST_F(CodeGenPassBuilderTest, start_stop) { + static const char *argv[] = { + "test", + "-start-after=no-op-module", + "-stop-before=no-op-function,2", + }; + int argc = std::size(argv); + cl::ParseCommandLineOptions(argc, argv); + + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + + PassInstrumentationCallbacks PIC; + DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC); + PipelineTuningOptions PTO; + PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC); + + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + ModulePassManager MPM; + MachineFunctionPassManager MFPM; + + Error Err = + CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null); + EXPECT_FALSE(Err); + std::string IRPipeline; + raw_string_ostream IROS(IRPipeline); + MPM.printPipeline(IROS, [&PIC](StringRef Name) { + auto PassName = PIC.getPassNameForClassName(Name); + return PassName.empty() ? Name : PassName; + }); + EXPECT_EQ(IRPipeline, "function(no-op-function)"); +} + } // namespace