diff options
Diffstat (limited to 'llvm/include/llvm/Support/BranchProbability.h')
| -rw-r--r-- | llvm/include/llvm/Support/BranchProbability.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/llvm/include/llvm/Support/BranchProbability.h b/llvm/include/llvm/Support/BranchProbability.h index 98b2a3d7d9a..f204777066e 100644 --- a/llvm/include/llvm/Support/BranchProbability.h +++ b/llvm/include/llvm/Support/BranchProbability.h @@ -15,7 +15,10 @@ #define LLVM_SUPPORT_BRANCHPROBABILITY_H #include "llvm/Support/DataTypes.h" +#include <algorithm> #include <cassert> +#include <climits> +#include <numeric> namespace llvm { @@ -53,6 +56,11 @@ public: template <class ProbabilityList> static void normalizeProbabilities(ProbabilityList &Probs); + // Normalize a list of weights by scaling them down so that the sum of them + // doesn't exceed UINT32_MAX. + template <class WeightListIter> + static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End); + uint32_t getNumerator() const { return N; } static uint32_t getDenominator() { return D; } @@ -135,6 +143,49 @@ void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) { Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum; } +template <class WeightListIter> +void BranchProbability::normalizeEdgeWeights(WeightListIter Begin, + WeightListIter End) { + // First we compute the sum with 64-bits of precision. + uint64_t Sum = std::accumulate(Begin, End, uint64_t(0)); + + if (Sum > UINT32_MAX) { + // Compute the scale necessary to cause the weights to fit, and re-sum with + // that scale applied. + assert(Sum / UINT32_MAX < UINT32_MAX && + "The sum of weights exceeds UINT32_MAX^2!"); + uint32_t Scale = Sum / UINT32_MAX + 1; + for (auto I = Begin; I != End; ++I) + *I /= Scale; + Sum = std::accumulate(Begin, End, uint64_t(0)); + } + + // Eliminate zero weights. + auto ZeroWeightNum = std::count(Begin, End, 0u); + if (ZeroWeightNum > 0) { + // If all weights are zeros, replace them by 1. + if (Sum == 0) + std::fill(Begin, End, 1u); + else { + // We are converting zeros into ones, and here we need to make sure that + // after this the sum won't exceed UINT32_MAX. + if (Sum + ZeroWeightNum > UINT32_MAX) { + for (auto I = Begin; I != End; ++I) + *I /= 2; + ZeroWeightNum = std::count(Begin, End, 0u); + Sum = std::accumulate(Begin, End, uint64_t(0)); + } + // Scale up non-zero weights and turn zero weights into ones. + uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum; + assert(ScalingFactor >= 1); + if (ScalingFactor > 1) + for (auto I = Begin; I != End; ++I) + *I *= ScalingFactor; + std::replace(Begin, End, 0u, 1u); + } + } +} + } #endif |

