diff options
Diffstat (limited to 'llvm/lib/CodeGen/SwitchLoweringUtils.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SwitchLoweringUtils.cpp | 94 |
1 files changed, 50 insertions, 44 deletions
diff --git a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp index 83acf7f8071..2b9999d0b41 100644 --- a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp +++ b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp @@ -11,33 +11,47 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallSet.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/SwitchLoweringUtils.h" using namespace llvm; using namespace SwitchCG; -uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters, - unsigned First, unsigned Last) { - assert(Last >= First); - const APInt &LowCase = Clusters[First].Low->getValue(); - const APInt &HighCase = Clusters[Last].High->getValue(); - assert(LowCase.getBitWidth() == HighCase.getBitWidth()); - - // FIXME: A range of consecutive cases has 100% density, but only requires one - // comparison to lower. We should discriminate against such consecutive ranges - // in jump tables. - return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1; -} +// Collection of partition stats, made up of, for a given cluster, +// the range of the cases, their number and the number of unique targets. +struct PartitionStats { + uint64_t Range, Cases, Targets; +}; + +static PartitionStats getJumpTableStats(const CaseClusterVector &Clusters, + unsigned First, unsigned Last, + bool HasReachableDefault) { + assert(Last >= First && "Invalid order of clusters"); + + SmallSet<const MachineBasicBlock *, 8> Targets; + PartitionStats Stats; + + Stats.Cases = 0; + for (unsigned i = First; i <= Last; ++i) { + const APInt &Hi = Clusters[i].High->getValue(), + &Lo = Clusters[i].Low->getValue(); + Stats.Cases += (Hi - Lo).getLimitedValue() + 1; + + Targets.insert(Clusters[i].MBB); + } + assert(Stats.Cases < UINT64_MAX / 100 && "Too many cases"); + + const APInt &Hi = Clusters[Last].High->getValue(), + &Lo = Clusters[First].Low->getValue(); + assert(Hi.getBitWidth() == Lo.getBitWidth()); + Stats.Range = (Hi - Lo).getLimitedValue((UINT64_MAX - 1) / 100) + 1; + assert(Stats.Range >= Stats.Cases && "Invalid range or number of cases"); + + Stats.Targets = + Targets.size() + (HasReachableDefault && Stats.Range > Stats.Cases); -uint64_t -SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases, - unsigned First, unsigned Last) { - assert(Last >= First); - assert(TotalCases[Last] >= TotalCases[First]); - uint64_t NumCases = - TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]); - return NumCases; + return Stats; } void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters, @@ -64,23 +78,13 @@ void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters, if (N < 2 || N < MinJumpTableEntries) return; - // Accumulated number of cases in each cluster and those prior to it. - SmallVector<unsigned, 8> TotalCases(N); - for (unsigned i = 0; i < N; ++i) { - const APInt &Hi = Clusters[i].High->getValue(); - const APInt &Lo = Clusters[i].Low->getValue(); - TotalCases[i] = (Hi - Lo).getLimitedValue() + 1; - if (i != 0) - TotalCases[i] += TotalCases[i - 1]; - } - - uint64_t Range = getJumpTableRange(Clusters,0, N - 1); - uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1); - assert(NumCases < UINT64_MAX / 100); - assert(Range >= NumCases); + const bool HasReachableDefault = + !isa<UnreachableInst>(DefaultMBB->getBasicBlock()->getFirstNonPHIOrDbg()); + PartitionStats Stats = + getJumpTableStats(Clusters, 0, N - 1, HasReachableDefault); // Cheap case: the whole range may be suitable for jump table. - if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) { + if (TLI->isSuitableForJumpTable(SI, Stats.Cases, Stats.Targets, Stats.Range)) { CaseCluster JTCluster; if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) { Clusters[0] = JTCluster; @@ -104,9 +108,6 @@ void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters, SmallVector<unsigned, 8> MinPartitions(N); // LastElement[i] is the last element of the partition starting at i. SmallVector<unsigned, 8> LastElement(N); - // PartitionsScore[i] is used to break ties when choosing between two - // partitionings resulting in the same number of partitions. - SmallVector<unsigned, 8> PartitionsScore(N); // For PartitionsScore, a small number of comparisons is considered as good as // a jump table and a single comparison is considered better than a jump // table. @@ -116,6 +117,11 @@ void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters, FewCases = 1, SingleCase = 2 }; + // PartitionsScore[i] is used to break ties when choosing between two + // partitionings resulting in the same number of partitions. + SmallVector<unsigned, 8> PartitionsScore(N); + // PartitionsStats[j] is the stats for the partition Clusters[i..j]. + SmallVector<PartitionStats, 8> PartitionsStats(N); // Base case: There is only one way to partition Clusters[N-1]. MinPartitions[N - 1] = 1; @@ -129,16 +135,16 @@ void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters, MinPartitions[i] = MinPartitions[i + 1] + 1; LastElement[i] = i; PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase; + for (int64_t j = i + 1; j < N; j++) + PartitionsStats[j] = + getJumpTableStats(Clusters, i, j, HasReachableDefault); // Search for a solution that results in fewer partitions. for (int64_t j = N - 1; j > i; j--) { // Try building a partition from Clusters[i..j]. - Range = getJumpTableRange(Clusters, i, j); - NumCases = getJumpTableNumCases(TotalCases, i, j); - assert(NumCases < UINT64_MAX / 100); - assert(Range >= NumCases); - - if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) { + if (TLI->isSuitableForJumpTable(SI, PartitionsStats[j].Cases, + PartitionsStats[j].Targets, + PartitionsStats[j].Range)) { unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]); unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1]; int64_t NumEntries = j - i + 1; |

