diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h index 8707e83574dd8..43587d953fc4c 100644 --- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h +++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h @@ -9,6 +9,8 @@ #ifndef LLVM_ANALYSIS_CTXPROFANALYSIS_H #define LLVM_ANALYSIS_CTXPROFANALYSIS_H +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PassManager.h" #include "llvm/ProfileData/PGOCtxProfReader.h" @@ -82,6 +84,8 @@ class CtxProfAnalysis : public AnalysisInfoMixin { using Result = PGOContextualProfile; PGOContextualProfile run(Module &M, ModuleAnalysisManager &MAM); + + static InstrProfCallsite *getCallsiteInstrumentation(CallBase &CB); }; class CtxProfAnalysisPrinterPass diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp index d0ccf4ba537f8..51663196b1307 100644 --- a/llvm/lib/Analysis/CtxProfAnalysis.cpp +++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp @@ -186,3 +186,10 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M, OS << "\n"; return PreservedAnalyses::all(); } + +InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) { + while (auto *Prev = CB.getPrevNode()) + if (auto *IPC = dyn_cast(Prev)) + return IPC; + return nullptr; +} diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt index d9eb81faac42a..1dec41972b357 100644 --- a/llvm/unittests/Analysis/CMakeLists.txt +++ b/llvm/unittests/Analysis/CMakeLists.txt @@ -22,6 +22,7 @@ set(ANALYSIS_TEST_SOURCES CFGTest.cpp CGSCCPassManagerTest.cpp ConstraintSystemTest.cpp + CtxProfAnalysisTest.cpp DDGTest.cpp DomTreeUpdaterTest.cpp DXILResourceTest.cpp diff --git a/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp new file mode 100644 index 0000000000000..5f9bf3ec540eb --- /dev/null +++ b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp @@ -0,0 +1,135 @@ +//===--- CtxProfAnalysisTest.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CtxProfAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Analysis.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { + +class CtxProfAnalysisTest : public testing::Test { + static constexpr auto *IR = R"IR( +declare void @bar() + +define private void @foo(i32 %a, ptr %fct) #0 !guid !0 { + %t = icmp eq i32 %a, 0 + br i1 %t, label %yes, label %no +yes: + call void %fct(i32 %a) + br label %exit +no: + call void @bar() + br label %exit +exit: + ret void +} + +define void @an_entrypoint(i32 %a) { + %t = icmp eq i32 %a, 0 + br i1 %t, label %yes, label %no + +yes: + call void @foo(i32 1, ptr null) + ret void +no: + ret void +} + +define void @another_entrypoint_no_callees(i32 %a) { + %t = icmp eq i32 %a, 0 + br i1 %t, label %yes, label %no + +yes: + ret void +no: + ret void +} + +attributes #0 = { noinline } +!0 = !{ i64 11872291593386833696 } +)IR"; + +protected: + LLVMContext C; + PassBuilder PB; + ModuleAnalysisManager MAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + LoopAnalysisManager LAM; + std::unique_ptr M; + + void SetUp() override { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + ASSERT_TRUE(!!M); + } + +public: + CtxProfAnalysisTest() { + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + } +}; + +TEST_F(CtxProfAnalysisTest, GetCallsiteIDTest) { + ModulePassManager MPM; + MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF)); + EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved()); + auto *F = M->getFunction("foo"); + ASSERT_NE(F, nullptr); + std::vector InsValues; + + for (auto &BB : *F) + for (auto &I : BB) + if (auto *CB = dyn_cast(&I)) { + // Skip instrumentation inserted intrinsics. + if (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic()) + continue; + auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB); + ASSERT_NE(Ins, nullptr); + InsValues.push_back(Ins->getIndex()->getZExtValue()); + } + + EXPECT_THAT(InsValues, testing::ElementsAre(0, 1)); +} + +TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) { + auto *F = M->getFunction("foo"); + ASSERT_NE(F, nullptr); + CallBase *FirstCall = nullptr; + for (auto &BB : *F) + for (auto &I : BB) + if (auto *CB = dyn_cast(&I)) { + if (CB->isIndirectCall() || !CB->getCalledFunction()->isIntrinsic()) { + FirstCall = CB; + break; + } + } + ASSERT_NE(FirstCall, nullptr); + auto *IndIns = CtxProfAnalysis::getCallsiteInstrumentation(*FirstCall); + EXPECT_EQ(IndIns, nullptr); +} + +} // namespace