diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp | 20 | ||||
-rw-r--r-- | llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp | 142 |
2 files changed, 151 insertions, 11 deletions
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index 6ca49cf5985..4f71261a1ab 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -107,6 +107,13 @@ StringRef InstrProfiling::getCoverageSection() const { return getInstrProfCoverageSectionName(isMachO()); } +static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) { + InstrProfIncrementInst *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr); + if (Inc) + return Inc; + return dyn_cast<InstrProfIncrementInst>(Instr); +} + bool InstrProfiling::run(Module &M) { bool MadeChange = false; @@ -138,7 +145,8 @@ bool InstrProfiling::run(Module &M) { for (BasicBlock &BB : F) for (auto I = BB.begin(), E = BB.end(); I != E;) { auto Instr = I++; - if (auto *Inc = dyn_cast<InstrProfIncrementInst>(Instr)) { + InstrProfIncrementInst *Inc = castToIncrementInst(&*Instr); + if (Inc) { lowerIncrement(Inc); MadeChange = true; } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(Instr)) { @@ -214,6 +222,14 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { Ind->eraseFromParent(); } +static Value *getIncrementStep(InstrProfIncrementInst *Inc, + IRBuilder<> &Builder) { + auto *IncWithStep = dyn_cast<InstrProfIncrementInstStep>(Inc); + if (IncWithStep) + return IncWithStep->getStep(); + return Builder.getInt64(1); +} + void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { GlobalVariable *Counters = getOrCreateRegionCounters(Inc); @@ -221,7 +237,7 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { uint64_t Index = Inc->getIndex()->getZExtValue(); Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index); Value *Count = Builder.CreateLoad(Addr, "pgocount"); - Count = Builder.CreateAdd(Count, Builder.getInt64(1)); + Count = Builder.CreateAdd(Count, getIncrementStep(Inc, Builder)); Inc->replaceAllUsesWith(Builder.CreateStore(Count, Addr)); Inc->eraseFromParent(); } diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 707329cf67e..f4ffe23420e 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -86,6 +86,7 @@ using namespace llvm; #define DEBUG_TYPE "pgo-instrumentation" STATISTIC(NumOfPGOInstrument, "Number of edges instrumented."); +STATISTIC(NumOfPGOSelectInsts, "Number of select instruction instrumented."); STATISTIC(NumOfPGOEdge, "Number of edges."); STATISTIC(NumOfPGOBB, "Number of basic-blocks."); STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); @@ -133,7 +134,65 @@ static cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden); +// Command line option to enable/disable select instruction instrumentation. +static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true), + cl::Hidden); namespace { + +/// The select instruction visitor plays three roles specified +/// by the mode. In \c VM_counting mode, it simply counts the number of +/// select instructions. In \c VM_instrument mode, it inserts code to count +/// the number times TrueValue of select is taken. In \c VM_annotate mode, +/// it reads the profile data and annotate the select instruction with metadata. +enum VisitMode { VM_counting, VM_instrument, VM_annotate }; +class PGOUseFunc; + +/// Instruction Visitor class to visit select instructions. +struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { + Function &F; + unsigned NSIs = 0; // Number of select instructions instrumented. + VisitMode Mode = VM_counting; // Visiting mode. + unsigned *CurCtrIdx = nullptr; // Pointer to current counter index. + unsigned TotalNumCtrs = 0; // Total number of counters + GlobalVariable *FuncNameVar = nullptr; + uint64_t FuncHash = 0; + PGOUseFunc *UseFunc = nullptr; + + SelectInstVisitor(Function &Func) : F(Func) {} + + void countSelects(Function &Func) { + Mode = VM_counting; + visit(Func); + } + // Visit the IR stream and instrument all select instructions. \p + // Ind is a pointer to the counter index variable; \p TotalNC + // is the total number of counters; \p FNV is the pointer to the + // PGO function name var; \p FHash is the function hash. + void instrumentSelects(Function &Func, unsigned *Ind, unsigned TotalNC, + GlobalVariable *FNV, uint64_t FHash) { + Mode = VM_instrument; + CurCtrIdx = Ind; + TotalNumCtrs = TotalNC; + FuncHash = FHash; + FuncNameVar = FNV; + visit(Func); + } + + // Visit the IR stream and annotate all select instructions. + void annotateSelects(Function &Func, PGOUseFunc *UF, unsigned *Ind) { + Mode = VM_annotate; + UseFunc = UF; + CurCtrIdx = Ind; + visit(Func); + } + + void instrumentOneSelectInst(SelectInst &SI); + void annotateOneSelectInst(SelectInst &SI); + // Visit \p SI instruction and perform tasks according to visit mode. + void visitSelectInst(SelectInst &SI); + unsigned getNumOfSelectInsts() const { return NSIs; } +}; + class PGOInstrumentationGenLegacyPass : public ModulePass { public: static char ID; @@ -180,6 +239,7 @@ private: AU.addRequired<BlockFrequencyInfoWrapperPass>(); } }; + } // end anonymous namespace char PGOInstrumentationGenLegacyPass::ID = 0; @@ -254,6 +314,7 @@ private: std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; public: + SelectInstVisitor SIVisitor; std::string FuncName; GlobalVariable *FuncNameVar; // CFG hash value for this function. @@ -280,8 +341,13 @@ public: std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr) - : F(Func), ComdatMembers(ComdatMembers), FunctionHash(0), + : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0), MST(F, BPI, BFI) { + + // This should be done before CFG hash computation. + SIVisitor.countSelects(Func); + NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); + FuncName = getPGOFuncName(F); computeCFGHash(); if (ComdatMembers.size()) @@ -308,7 +374,7 @@ public: if (!E->InMST && !E->Removed) NumCounters++; } - return NumCounters; + return NumCounters + SIVisitor.getNumOfSelectInsts(); } }; @@ -328,7 +394,8 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { } } JC.update(Indexes); - FunctionHash = (uint64_t)findIndirectCallSites(F).size() << 48 | + FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | + (uint64_t)findIndirectCallSites(F).size() << 48 | (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); } @@ -473,6 +540,10 @@ static void instrumentOneFunc( Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters), Builder.getInt32(I++)}); } + + // Now instrument select instructions: + FuncInfo.SIVisitor.instrumentSelects(F, &I, NumCounters, FuncInfo.FuncNameVar, + FuncInfo.FunctionHash); assert(I == NumCounters); if (DisableValueProfiling) @@ -594,17 +665,17 @@ public: // Return the profile record for this function; InstrProfRecord &getProfileRecord() { return ProfileRecord; } + // Return the auxiliary BB information. + UseBBInfo &getBBInfo(const BasicBlock *BB) const { + return FuncInfo.getBBInfo(BB); + } + private: Function &F; Module *M; // This member stores the shared information with class PGOGenFunc. FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo; - // Return the auxiliary BB information. - UseBBInfo &getBBInfo(const BasicBlock *BB) const { - return FuncInfo.getBBInfo(BB); - } - // The maximum count value in the profile. This is only used in PGO use // compilation. uint64_t ProgramMaxCount; @@ -677,6 +748,8 @@ void PGOUseFunc::setInstrumentedCounts( NewEdge1.InMST = true; getBBInfo(InstrBB).setBBInfoCount(CountValue); } + // Now annotate select instructions + FuncInfo.SIVisitor.annotateSelects(F, this, &I); assert(I == CountFromProfile.size()); } @@ -820,7 +893,7 @@ void PGOUseFunc::populateCounters() { DEBUG(FuncInfo.dumpInfo("after reading profile.")); } -static void setProfMetadata(Module *M, TerminatorInst *TI, +static void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) { MDBuilder MDB(M->getContext()); assert(MaxCount > 0 && "Bad max count"); @@ -869,6 +942,57 @@ void PGOUseFunc::setBranchWeights() { } } +void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { + Module *M = F.getParent(); + IRBuilder<> Builder(&SI); + Type *Int64Ty = Builder.getInt64Ty(); + Type *I8PtrTy = Builder.getInt8PtrTy(); + auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty); + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step), + {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), + Builder.getInt64(FuncHash), + Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step}); + ++(*CurCtrIdx); +} + +void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) { + std::vector<uint64_t> &CountFromProfile = UseFunc->getProfileRecord().Counts; + assert(*CurCtrIdx < CountFromProfile.size() && + "Out of bound access of counters"); + uint64_t SCounts[2]; + SCounts[0] = CountFromProfile[*CurCtrIdx]; // True count + ++(*CurCtrIdx); + uint64_t TotalCount = UseFunc->getBBInfo(SI.getParent()).CountValue; + // False Count + SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0); + uint64_t MaxCount = std::max(SCounts[0], SCounts[1]); + setProfMetadata(F.getParent(), &SI, SCounts, MaxCount); +} + +void SelectInstVisitor::visitSelectInst(SelectInst &SI) { + if (!PGOInstrSelect) + return; + // FIXME: do not handle this yet. + if (SI.getCondition()->getType()->isVectorTy()) + return; + + NSIs++; + switch (Mode) { + case VM_counting: + return; + case VM_instrument: + instrumentOneSelectInst(SI); + break; + case VM_annotate: + annotateOneSelectInst(SI); + break; + default: + assert(false && "Unknown visiting mode"); + break; + } +} + // Traverse all the indirect callsites and annotate the instructions. void PGOUseFunc::annotateIndirectCallSites() { if (DisableValueProfiling) |