summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Analysis/SyntheticCountsUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/SyntheticCountsUtils.cpp')
-rw-r--r--llvm/lib/Analysis/SyntheticCountsUtils.cpp122
1 files changed, 122 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/llvm/lib/Analysis/SyntheticCountsUtils.cpp
new file mode 100644
index 00000000000..262299c5f3b
--- /dev/null
+++ b/llvm/lib/Analysis/SyntheticCountsUtils.cpp
@@ -0,0 +1,122 @@
+//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities for propagating synthetic counts.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/SyntheticCountsUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+
+using namespace llvm;
+
+// Given a set of functions in an SCC, propagate entry counts to functions
+// called by the SCC.
+static void
+propagateFromSCC(const SmallPtrSetImpl<Function *> &SCCFunctions,
+ function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
+ function_ref<uint64_t(Function *F)> GetCount,
+ function_ref<void(Function *F, uint64_t)> AddToCount) {
+
+ SmallVector<CallSite, 16> CallSites;
+
+ // Gather all callsites in the SCC.
+ auto GatherCallSites = [&]() {
+ for (auto *F : SCCFunctions) {
+ assert(F && !F->isDeclaration());
+ for (auto &I : instructions(F)) {
+ if (auto CS = CallSite(&I)) {
+ CallSites.push_back(CS);
+ }
+ }
+ }
+ };
+
+ GatherCallSites();
+
+ // Partition callsites so that the callsites that call functions in the same
+ // SCC come first.
+ auto Mid = partition(CallSites, [&](CallSite &CS) {
+ auto *Callee = CS.getCalledFunction();
+ if (Callee)
+ return SCCFunctions.count(Callee);
+ // FIXME: Use the !callees metadata to propagate counts through indirect
+ // calls.
+ return 0U;
+ });
+
+ // For functions in the same SCC, update the counts in two steps:
+ // 1. Compute the additional count for each function by propagating the counts
+ // along all incoming edges to the function that originate from the same SCC
+ // and summing them up.
+ // 2. Add the additional counts to the functions in the SCC.
+ // This ensures that the order of
+ // traversal of functions within the SCC doesn't change the final result.
+
+ DenseMap<Function *, uint64_t> AdditionalCounts;
+ for (auto It = CallSites.begin(); It != Mid; It++) {
+ auto &CS = *It;
+ auto RelFreq = GetCallSiteRelFreq(CS);
+ Function *Callee = CS.getCalledFunction();
+ Function *Caller = CS.getCaller();
+ RelFreq *= Scaled64(GetCount(Caller), 0);
+ uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
+ AdditionalCounts[Callee] += AdditionalCount;
+ }
+
+ // Update the counts for the functions in the SCC.
+ for (auto &Entry : AdditionalCounts)
+ AddToCount(Entry.first, Entry.second);
+
+ // Now update the counts for functions not in SCC.
+ for (auto It = Mid; It != CallSites.end(); It++) {
+ auto &CS = *It;
+ auto Weight = GetCallSiteRelFreq(CS);
+ Function *Callee = CS.getCalledFunction();
+ Function *Caller = CS.getCaller();
+ Weight *= Scaled64(GetCount(Caller), 0);
+ AddToCount(Callee, Weight.toInt<uint64_t>());
+ }
+}
+
+/// Propgate synthetic entry counts on a callgraph.
+///
+/// This performs a reverse post-order traversal of the callgraph SCC. For each
+/// SCC, it first propagates the entry counts to the functions within the SCC
+/// through call edges and updates them in one shot. Then the entry counts are
+/// propagated to functions outside the SCC.
+void llvm::propagateSyntheticCounts(
+ const CallGraph &CG, function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
+ function_ref<uint64_t(Function *F)> GetCount,
+ function_ref<void(Function *F, uint64_t)> AddToCount) {
+
+ SmallVector<SmallPtrSet<Function *, 8>, 16> SCCs;
+ for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) {
+ auto SCC = *I;
+
+ SmallPtrSet<Function *, 8> SCCFunctions;
+ for (auto *Node : SCC) {
+ Function *F = Node->getFunction();
+ if (F && !F->isDeclaration()) {
+ SCCFunctions.insert(F);
+ }
+ }
+ SCCs.push_back(SCCFunctions);
+ }
+
+ for (auto &SCCFunctions : reverse(SCCs))
+ propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount);
+}
OpenPOWER on IntegriCloud