summaryrefslogtreecommitdiffstats
path: root/clang/lib/CodeGen/CodeGenPGO.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/CodeGen/CodeGenPGO.cpp')
-rw-r--r--clang/lib/CodeGen/CodeGenPGO.cpp456
1 files changed, 456 insertions, 0 deletions
diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp
new file mode 100644
index 00000000000..d1b57d80c35
--- /dev/null
+++ b/clang/lib/CodeGen/CodeGenPGO.cpp
@@ -0,0 +1,456 @@
+//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Instrumentation-based profile-guided optimization
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodeGenPGO.h"
+#include "CodeGenFunction.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/StmtVisitor.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/Support/FileSystem.h"
+
+using namespace clang;
+using namespace CodeGen;
+
+static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
+ DiagnosticsEngine &Diags = CGM.getDiags();
+ unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, Message);
+ Diags.Report(DiagID);
+}
+
+PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
+ : CGM(CGM) {
+ if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
+ ReportBadPGOData(CGM, "failed to open pgo data file");
+ return;
+ }
+
+ if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
+ ReportBadPGOData(CGM, "pgo data file too big");
+ return;
+ }
+
+ // Scan through the data file and map each function to the corresponding
+ // file offset where its counts are stored.
+ const char *BufferStart = DataBuffer->getBufferStart();
+ const char *BufferEnd = DataBuffer->getBufferEnd();
+ const char *CurPtr = BufferStart;
+ while (CurPtr < BufferEnd) {
+ // Read the mangled function name.
+ const char *FuncName = CurPtr;
+ // FIXME: Something will need to be added to distinguish static functions.
+ CurPtr = strchr(CurPtr, ' ');
+ if (!CurPtr) {
+ ReportBadPGOData(CGM, "pgo data file has malformed function entry");
+ return;
+ }
+ StringRef MangledName(FuncName, CurPtr - FuncName);
+
+ // Read the number of counters.
+ char *EndPtr;
+ unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
+ if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
+ ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
+ return;
+ }
+ CurPtr = EndPtr;
+
+ // There is one line for each counter; skip over those lines.
+ for (unsigned N = 0; N < NumCounters; ++N) {
+ CurPtr = strchr(++CurPtr, '\n');
+ if (!CurPtr) {
+ ReportBadPGOData(CGM, "pgo data file is missing some counter info");
+ return;
+ }
+ }
+
+ // Skip over the blank line separating functions.
+ CurPtr += 2;
+
+ DataOffsets[MangledName] = FuncName - BufferStart;
+ }
+}
+
+bool PGOProfileData::getFunctionCounts(StringRef MangledName,
+ std::vector<uint64_t> &Counts) {
+ // Find the relevant section of the pgo-data file.
+ llvm::StringMap<unsigned>::const_iterator OffsetIter =
+ DataOffsets.find(MangledName);
+ if (OffsetIter == DataOffsets.end())
+ return true;
+ const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
+
+ // Skip over the function name.
+ CurPtr = strchr(CurPtr, ' ');
+ assert(CurPtr && "pgo-data has corrupted function entry");
+
+ // Read the number of counters.
+ char *EndPtr;
+ unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
+ assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
+ "pgo-data file has corrupted number of counters");
+ CurPtr = EndPtr;
+
+ Counts.reserve(NumCounters);
+
+ for (unsigned N = 0; N < NumCounters; ++N) {
+ // Read the count value.
+ uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
+ if (EndPtr == CurPtr || *EndPtr != '\n') {
+ ReportBadPGOData(CGM, "pgo-data file has bad count value");
+ return true;
+ }
+ Counts.push_back(Count);
+ CurPtr = EndPtr + 1;
+ }
+
+ // Make sure the number of counters matches up.
+ if (Counts.size() != NumCounters) {
+ ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
+ return true;
+ }
+
+ return false;
+}
+
+void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
+ if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
+ return;
+
+ llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+
+ llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
+ llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
+
+ llvm::Function *WriteoutF =
+ CGM.getModule().getFunction("__llvm_pgo_writeout");
+ if (!WriteoutF) {
+ llvm::FunctionType *WriteoutFTy =
+ llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
+ WriteoutF = llvm::Function::Create(WriteoutFTy,
+ llvm::GlobalValue::InternalLinkage,
+ "__llvm_pgo_writeout", &CGM.getModule());
+ }
+ WriteoutF->setUnnamedAddr(true);
+ WriteoutF->addFnAttr(llvm::Attribute::NoInline);
+ if (CGM.getCodeGenOpts().DisableRedZone)
+ WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
+
+ llvm::BasicBlock *BB = WriteoutF->empty() ?
+ llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
+
+ CGBuilderTy PGOBuilder(BB);
+
+ llvm::Instruction *I = BB->getTerminator();
+ if (!I)
+ I = PGOBuilder.CreateRetVoid();
+ PGOBuilder.SetInsertPoint(I);
+
+ llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
+ llvm::Type *Args[] = {
+ Int8PtrTy, // const char *MangledName
+ Int32Ty, // uint32_t NumCounters
+ Int64PtrTy // uint64_t *Counters
+ };
+ llvm::FunctionType *FTy =
+ llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
+ llvm::Constant *EmitFunc =
+ CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
+
+ llvm::Constant *MangledName =
+ CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
+ MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
+ PGOBuilder.CreateCall3(EmitFunc, MangledName,
+ PGOBuilder.getInt32(NumRegionCounters),
+ PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
+}
+
+llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
+ llvm::Function *WriteoutF =
+ CGM.getModule().getFunction("__llvm_pgo_writeout");
+ if (!WriteoutF)
+ return NULL;
+
+ // Create a small bit of code that registers the "__llvm_pgo_writeout" to
+ // be executed at exit.
+ llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
+ if (F)
+ return NULL;
+
+ llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+ llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
+ false);
+ F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
+ "__llvm_pgo_init", &CGM.getModule());
+ F->setUnnamedAddr(true);
+ F->setLinkage(llvm::GlobalValue::InternalLinkage);
+ F->addFnAttr(llvm::Attribute::NoInline);
+ if (CGM.getCodeGenOpts().DisableRedZone)
+ F->addFnAttr(llvm::Attribute::NoRedZone);
+
+ llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
+ CGBuilderTy PGOBuilder(BB);
+
+ FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
+ llvm::Type *Params[] = {
+ llvm::PointerType::get(FTy, 0)
+ };
+ FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
+
+ // Inialize the environment and register the local writeout function.
+ llvm::Constant *PGOInit =
+ CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
+ PGOBuilder.CreateCall(PGOInit, WriteoutF);
+ PGOBuilder.CreateRetVoid();
+
+ return F;
+}
+
+namespace {
+ /// A StmtVisitor that fills a map of statements to PGO counters.
+ struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
+ /// The next counter value to assign.
+ unsigned NextCounter;
+ /// The map of statements to counters.
+ llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
+
+ MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
+ NextCounter(0), CounterMap(CounterMap) {
+ }
+
+ void VisitChildren(const Stmt *S) {
+ for (Stmt::const_child_range I = S->children(); I; ++I)
+ if (*I)
+ this->Visit(*I);
+ }
+ void VisitStmt(const Stmt *S) { VisitChildren(S); }
+
+ /// Assign a counter to track entry to the function body
+ void VisitFunctionDecl(const FunctionDecl *S) {
+ (*CounterMap)[S->getBody()] = NextCounter++;
+ Visit(S->getBody());
+ }
+ /// Assign a counter to track the block following a label
+ void VisitLabelStmt(const LabelStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getSubStmt());
+ }
+ /// Assign three counters - one for the body of the loop, one for breaks
+ /// from the loop, and one for continues.
+ ///
+ /// The break and continue counters cover all such statements in this loop,
+ /// and are used in calculations to find the number of times the condition
+ /// and exit of the loop occur. They are needed so we can differentiate
+ /// these statements from non-local exits like return and goto.
+ void VisitWhileStmt(const WhileStmt *S) {
+ (*CounterMap)[S] = NextCounter;
+ NextCounter += 3;
+ Visit(S->getCond());
+ Visit(S->getBody());
+ }
+ /// Assign counters for the body of the loop, and for breaks and
+ /// continues. See VisitWhileStmt.
+ void VisitDoStmt(const DoStmt *S) {
+ (*CounterMap)[S] = NextCounter;
+ NextCounter += 3;
+ Visit(S->getBody());
+ Visit(S->getCond());
+ }
+ /// Assign counters for the body of the loop, and for breaks and
+ /// continues. See VisitWhileStmt.
+ void VisitForStmt(const ForStmt *S) {
+ (*CounterMap)[S] = NextCounter;
+ NextCounter += 3;
+ const Expr *E;
+ if ((E = S->getCond()))
+ Visit(E);
+ Visit(S->getBody());
+ if ((E = S->getInc()))
+ Visit(E);
+ }
+ /// Assign counters for the body of the loop, and for breaks and
+ /// continues. See VisitWhileStmt.
+ void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
+ (*CounterMap)[S] = NextCounter;
+ NextCounter += 3;
+ const Expr *E;
+ if ((E = S->getCond()))
+ Visit(E);
+ Visit(S->getBody());
+ if ((E = S->getInc()))
+ Visit(E);
+ }
+ /// Assign counters for the body of the loop, and for breaks and
+ /// continues. See VisitWhileStmt.
+ void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
+ (*CounterMap)[S] = NextCounter;
+ NextCounter += 3;
+ Visit(S->getElement());
+ Visit(S->getBody());
+ }
+ /// Assign a counter for the exit block of the switch statement.
+ void VisitSwitchStmt(const SwitchStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getCond());
+ Visit(S->getBody());
+ }
+ /// Assign a counter for a particular case in a switch. This counts jumps
+ /// from the switch header as well as fallthrough from the case before this
+ /// one.
+ void VisitCaseStmt(const CaseStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getSubStmt());
+ }
+ /// Assign a counter for the default case of a switch statement. The count
+ /// is the number of branches from the loop header to the default, and does
+ /// not include fallthrough from previous cases. If we have multiple
+ /// conditional branch blocks from the switch instruction to the default
+ /// block, as with large GNU case ranges, this is the counter for the last
+ /// edge in that series, rather than the first.
+ void VisitDefaultStmt(const DefaultStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getSubStmt());
+ }
+ /// Assign a counter for the "then" part of an if statement. The count for
+ /// the "else" part, if it exists, will be calculated from this counter.
+ void VisitIfStmt(const IfStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getCond());
+ Visit(S->getThen());
+ if (S->getElse())
+ Visit(S->getElse());
+ }
+ /// Assign a counter for the continuation block of a C++ try statement.
+ void VisitCXXTryStmt(const CXXTryStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getTryBlock());
+ for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
+ Visit(S->getHandler(I));
+ }
+ /// Assign a counter for a catch statement's handler block.
+ void VisitCXXCatchStmt(const CXXCatchStmt *S) {
+ (*CounterMap)[S] = NextCounter++;
+ Visit(S->getHandlerBlock());
+ }
+ /// Assign a counter for the "true" part of a conditional operator. The
+ /// count in the "false" part will be calculated from this counter.
+ void VisitConditionalOperator(const ConditionalOperator *E) {
+ (*CounterMap)[E] = NextCounter++;
+ Visit(E->getCond());
+ Visit(E->getTrueExpr());
+ Visit(E->getFalseExpr());
+ }
+ /// Assign a counter for the right hand side of a logical and operator.
+ void VisitBinLAnd(const BinaryOperator *E) {
+ (*CounterMap)[E] = NextCounter++;
+ Visit(E->getLHS());
+ Visit(E->getRHS());
+ }
+ /// Assign a counter for the right hand side of a logical or operator.
+ void VisitBinLOr(const BinaryOperator *E) {
+ (*CounterMap)[E] = NextCounter++;
+ Visit(E->getLHS());
+ Visit(E->getRHS());
+ }
+ };
+}
+
+void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
+ bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
+ PGOProfileData *PGOData = CGM.getPGOData();
+ if (!InstrumentRegions && !PGOData)
+ return;
+ const Decl *D = GD.getDecl();
+ if (!D)
+ return;
+ mapRegionCounters(D);
+ if (InstrumentRegions)
+ emitCounterVariables();
+ if (PGOData)
+ loadRegionCounts(GD, PGOData);
+}
+
+void CodeGenPGO::mapRegionCounters(const Decl *D) {
+ RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
+ MapRegionCounters Walker(RegionCounterMap);
+ if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
+ Walker.VisitFunctionDecl(FD);
+ NumRegionCounters = Walker.NextCounter;
+}
+
+void CodeGenPGO::emitCounterVariables() {
+ llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+ llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
+ NumRegionCounters);
+ RegionCounters =
+ new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
+ llvm::GlobalVariable::PrivateLinkage,
+ llvm::Constant::getNullValue(CounterTy),
+ "__llvm_pgo_ctr");
+}
+
+void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
+ if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
+ return;
+ llvm::Value *Addr =
+ Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
+ llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
+ Count = Builder.CreateAdd(Count, Builder.getInt64(1));
+ Builder.CreateStore(Count, Addr);
+}
+
+void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
+ // For now, ignore the counts from the PGO data file only if the number of
+ // counters does not match. This could be tightened down in the future to
+ // ignore counts when the input changes in various ways, e.g., by comparing a
+ // hash value based on some characteristics of the input.
+ RegionCounts = new std::vector<uint64_t>();
+ if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
+ RegionCounts->size() != NumRegionCounters) {
+ delete RegionCounts;
+ RegionCounts = 0;
+ }
+}
+
+void CodeGenPGO::destroyRegionCounters() {
+ if (RegionCounterMap != 0)
+ delete RegionCounterMap;
+ if (RegionCounts != 0)
+ delete RegionCounts;
+}
+
+llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
+ uint64_t FalseCount) {
+ if (!TrueCount && !FalseCount)
+ return 0;
+
+ llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+ // TODO: need to scale down to 32-bits
+ // According to Laplace's Rule of Succession, it is better to compute the
+ // weight based on the count plus 1.
+ return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
+}
+
+llvm::MDNode *
+CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
+ llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+ // TODO: need to scale down to 32-bits, instead of just truncating.
+ // According to Laplace's Rule of Succession, it is better to compute the
+ // weight based on the count plus 1.
+ SmallVector<uint32_t, 16> ScaledWeights;
+ ScaledWeights.reserve(Weights.size());
+ for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
+ WI != WE; ++WI) {
+ ScaledWeights.push_back(*WI + 1);
+ }
+ return MDHelper.createBranchWeights(ScaledWeights);
+}
OpenPOWER on IntegriCloud