diff options
| -rw-r--r-- | llvm/include/llvm/Analysis/SyntheticCountsUtils.h | 15 | ||||
| -rw-r--r-- | llvm/lib/Analysis/SyntheticCountsUtils.cpp | 29 | ||||
| -rw-r--r-- | llvm/lib/LTO/SummaryBasedOptimizations.cpp | 14 | ||||
| -rw-r--r-- | llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp | 35 |
4 files changed, 49 insertions, 44 deletions
diff --git a/llvm/include/llvm/Analysis/SyntheticCountsUtils.h b/llvm/include/llvm/Analysis/SyntheticCountsUtils.h index 87f4a0100b3..db80bef001e 100644 --- a/llvm/include/llvm/Analysis/SyntheticCountsUtils.h +++ b/llvm/include/llvm/Analysis/SyntheticCountsUtils.h @@ -36,16 +36,17 @@ public: using EdgeRef = typename CGT::EdgeRef; using SccTy = std::vector<NodeRef>; - using GetRelBBFreqTy = function_ref<Optional<Scaled64>(EdgeRef)>; - using GetCountTy = function_ref<uint64_t(NodeRef)>; - using AddCountTy = function_ref<void(NodeRef, uint64_t)>; + // Not all EdgeRef have information about the source of the edge. Hence + // NodeRef corresponding to the source of the EdgeRef is explicitly passed. + using GetProfCountTy = function_ref<Optional<Scaled64>(NodeRef, EdgeRef)>; + using AddCountTy = function_ref<void(NodeRef, Scaled64)>; - static void propagate(const CallGraphType &CG, GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, AddCountTy AddCount); + static void propagate(const CallGraphType &CG, GetProfCountTy GetProfCount, + AddCountTy AddCount); private: - static void propagateFromSCC(const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, AddCountTy AddCount); + static void propagateFromSCC(const SccTy &SCC, GetProfCountTy GetProfCount, + AddCountTy AddCount); }; } // namespace llvm diff --git a/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/llvm/lib/Analysis/SyntheticCountsUtils.cpp index 386396bcff3..c2d7bb11a4c 100644 --- a/llvm/lib/Analysis/SyntheticCountsUtils.cpp +++ b/llvm/lib/Analysis/SyntheticCountsUtils.cpp @@ -26,8 +26,7 @@ using namespace llvm; // Given an SCC, propagate entry counts along the edge of the SCC nodes. template <typename CallGraphType> void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( - const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount, - AddCountTy AddCount) { + const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) { DenseSet<NodeRef> SCCNodes; SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges; @@ -54,17 +53,13 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( // This ensures that the order of // traversal of nodes within the SCC doesn't affect the final result. - DenseMap<NodeRef, uint64_t> AdditionalCounts; + DenseMap<NodeRef, Scaled64> AdditionalCounts; for (auto &E : SCCEdges) { - auto OptRelFreq = GetRelBBFreq(E.second); - if (!OptRelFreq) + auto OptProfCount = GetProfCount(E.first, E.second); + if (!OptProfCount) continue; - Scaled64 RelFreq = OptRelFreq.getValue(); - auto Caller = E.first; auto Callee = CGT::edge_dest(E.second); - RelFreq *= Scaled64(GetCount(Caller), 0); - uint64_t AdditionalCount = RelFreq.toInt<uint64_t>(); - AdditionalCounts[Callee] += AdditionalCount; + AdditionalCounts[Callee] += OptProfCount.getValue(); } // Update the counts for the nodes in the SCC. @@ -73,14 +68,11 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( // Now update the counts for nodes outside the SCC. for (auto &E : NonSCCEdges) { - auto OptRelFreq = GetRelBBFreq(E.second); - if (!OptRelFreq) + auto OptProfCount = GetProfCount(E.first, E.second); + if (!OptProfCount) continue; - Scaled64 RelFreq = OptRelFreq.getValue(); - auto Caller = E.first; auto Callee = CGT::edge_dest(E.second); - RelFreq *= Scaled64(GetCount(Caller), 0); - AddCount(Callee, RelFreq.toInt<uint64_t>()); + AddCount(Callee, OptProfCount.getValue()); } } @@ -94,8 +86,7 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( template <typename CallGraphType> void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG, - GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, + GetProfCountTy GetProfCount, AddCountTy AddCount) { std::vector<SccTy> SCCs; @@ -107,7 +98,7 @@ void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG, // The scc iterator returns the scc in bottom-up order, so reverse the SCCs // and call propagateFromSCC. for (auto &SCC : reverse(SCCs)) - propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount); + propagateFromSCC(SCC, GetProfCount, AddCount); } template class llvm::SyntheticCountsUtils<const CallGraph *>; diff --git a/llvm/lib/LTO/SummaryBasedOptimizations.cpp b/llvm/lib/LTO/SummaryBasedOptimizations.cpp index 8b1abb78462..bcdd984daa5 100644 --- a/llvm/lib/LTO/SummaryBasedOptimizations.cpp +++ b/llvm/lib/LTO/SummaryBasedOptimizations.cpp @@ -60,21 +60,27 @@ void llvm::computeSyntheticCounts(ModuleSummaryIndex &Index) { return UINT64_C(0); } }; - auto AddToEntryCount = [](ValueInfo V, uint64_t New) { + auto AddToEntryCount = [](ValueInfo V, Scaled64 New) { if (!V.getSummaryList().size()) return; for (auto &GVS : V.getSummaryList()) { auto S = GVS.get()->getBaseObject(); auto *F = cast<FunctionSummary>(S); - F->setEntryCount(SaturatingAdd(F->entryCount(), New)); + F->setEntryCount( + SaturatingAdd(F->entryCount(), New.template toInt<uint64_t>())); } }; + auto GetProfileCount = [&](ValueInfo V, FunctionSummary::EdgeTy &Edge) { + auto RelFreq = GetCallSiteRelFreq(Edge); + Scaled64 EC(GetEntryCount(V), 0); + return RelFreq * EC; + }; // After initializing the counts in initializeCounts above, the counts have to // be propagated across the combined callgraph. // SyntheticCountsUtils::propagate takes care of this propagation on any // callgraph that specialized GraphTraits. - SyntheticCountsUtils<ModuleSummaryIndex *>::propagate( - &Index, GetCallSiteRelFreq, GetEntryCount, AddToEntryCount); + SyntheticCountsUtils<ModuleSummaryIndex *>::propagate(&Index, GetProfileCount, + AddToEntryCount); Index.setHasSyntheticEntryCounts(); } diff --git a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index 64837d4f5d6..ba4efb3ff60 100644 --- a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/SyntheticCountsUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Function.h" @@ -98,13 +99,15 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, ModuleAnalysisManager &MAM) { FunctionAnalysisManager &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - DenseMap<Function *, uint64_t> Counts; + DenseMap<Function *, Scaled64> Counts; // Set initial entry counts. - initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; }); + initializeCounts( + M, [&](Function *F, uint64_t Count) { Counts[F] = Scaled64(Count, 0); }); - // Compute the relative block frequency for a call edge. Use scaled numbers - // and not integers since the relative block frequency could be less than 1. - auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) { + // Edge includes information about the source. Hence ignore the first + // parameter. + auto GetCallSiteProfCount = [&](const CallGraphNode *, + const CallGraphNode::CallRecord &Edge) { Optional<Scaled64> Res = None; if (!Edge.first) return Res; @@ -112,29 +115,33 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, CallSite CS(cast<Instruction>(Edge.first)); Function *Caller = CS.getCaller(); auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*Caller); + + // Now compute the callsite count from relative frequency and + // entry count: BasicBlock *CSBB = CS.getInstruction()->getParent(); Scaled64 EntryFreq(BFI.getEntryFreq(), 0); - Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0); - BBFreq /= EntryFreq; - return Optional<Scaled64>(BBFreq); + Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0); + BBCount /= EntryFreq; + BBCount *= Counts[Caller]; + return Optional<Scaled64>(BBCount); }; CallGraph CG(M); // Propgate the entry counts on the callgraph. SyntheticCountsUtils<const CallGraph *>::propagate( - &CG, GetCallSiteRelFreq, - [&](const CallGraphNode *N) { return Counts[N->getFunction()]; }, - [&](const CallGraphNode *N, uint64_t New) { + &CG, GetCallSiteProfCount, [&](const CallGraphNode *N, Scaled64 New) { auto F = N->getFunction(); if (!F || F->isDeclaration()) return; + Counts[F] += New; }); // Set the counts as metadata. - for (auto Entry : Counts) - Entry.first->setEntryCount( - ProfileCount(Entry.second, Function::PCT_Synthetic)); + for (auto Entry : Counts) { + Entry.first->setEntryCount(ProfileCount( + Entry.second.template toInt<uint64_t>(), Function::PCT_Synthetic)); + } return PreservedAnalyses::all(); } |

