diff options
Diffstat (limited to 'llvm/lib/Analysis/SyntheticCountsUtils.cpp')
-rw-r--r-- | llvm/lib/Analysis/SyntheticCountsUtils.cpp | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/llvm/lib/Analysis/SyntheticCountsUtils.cpp new file mode 100644 index 00000000000..262299c5f3b --- /dev/null +++ b/llvm/lib/Analysis/SyntheticCountsUtils.cpp @@ -0,0 +1,122 @@ +//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities for propagating synthetic counts. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SyntheticCountsUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +// Given a set of functions in an SCC, propagate entry counts to functions +// called by the SCC. +static void +propagateFromSCC(const SmallPtrSetImpl<Function *> &SCCFunctions, + function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq, + function_ref<uint64_t(Function *F)> GetCount, + function_ref<void(Function *F, uint64_t)> AddToCount) { + + SmallVector<CallSite, 16> CallSites; + + // Gather all callsites in the SCC. + auto GatherCallSites = [&]() { + for (auto *F : SCCFunctions) { + assert(F && !F->isDeclaration()); + for (auto &I : instructions(F)) { + if (auto CS = CallSite(&I)) { + CallSites.push_back(CS); + } + } + } + }; + + GatherCallSites(); + + // Partition callsites so that the callsites that call functions in the same + // SCC come first. + auto Mid = partition(CallSites, [&](CallSite &CS) { + auto *Callee = CS.getCalledFunction(); + if (Callee) + return SCCFunctions.count(Callee); + // FIXME: Use the !callees metadata to propagate counts through indirect + // calls. + return 0U; + }); + + // For functions in the same SCC, update the counts in two steps: + // 1. Compute the additional count for each function by propagating the counts + // along all incoming edges to the function that originate from the same SCC + // and summing them up. + // 2. Add the additional counts to the functions in the SCC. + // This ensures that the order of + // traversal of functions within the SCC doesn't change the final result. + + DenseMap<Function *, uint64_t> AdditionalCounts; + for (auto It = CallSites.begin(); It != Mid; It++) { + auto &CS = *It; + auto RelFreq = GetCallSiteRelFreq(CS); + Function *Callee = CS.getCalledFunction(); + Function *Caller = CS.getCaller(); + RelFreq *= Scaled64(GetCount(Caller), 0); + uint64_t AdditionalCount = RelFreq.toInt<uint64_t>(); + AdditionalCounts[Callee] += AdditionalCount; + } + + // Update the counts for the functions in the SCC. + for (auto &Entry : AdditionalCounts) + AddToCount(Entry.first, Entry.second); + + // Now update the counts for functions not in SCC. + for (auto It = Mid; It != CallSites.end(); It++) { + auto &CS = *It; + auto Weight = GetCallSiteRelFreq(CS); + Function *Callee = CS.getCalledFunction(); + Function *Caller = CS.getCaller(); + Weight *= Scaled64(GetCount(Caller), 0); + AddToCount(Callee, Weight.toInt<uint64_t>()); + } +} + +/// Propgate synthetic entry counts on a callgraph. +/// +/// This performs a reverse post-order traversal of the callgraph SCC. For each +/// SCC, it first propagates the entry counts to the functions within the SCC +/// through call edges and updates them in one shot. Then the entry counts are +/// propagated to functions outside the SCC. +void llvm::propagateSyntheticCounts( + const CallGraph &CG, function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq, + function_ref<uint64_t(Function *F)> GetCount, + function_ref<void(Function *F, uint64_t)> AddToCount) { + + SmallVector<SmallPtrSet<Function *, 8>, 16> SCCs; + for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) { + auto SCC = *I; + + SmallPtrSet<Function *, 8> SCCFunctions; + for (auto *Node : SCC) { + Function *F = Node->getFunction(); + if (F && !F->isDeclaration()) { + SCCFunctions.insert(F); + } + } + SCCs.push_back(SCCFunctions); + } + + for (auto &SCCFunctions : reverse(SCCs)) + propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount); +} |