Diffstat (limited to 'llvm/lib/CodeGen/SwitchLoweringUtils.cpp')
-rw-r--r--  llvm/lib/CodeGen/SwitchLoweringUtils.cpp  486
1 file changed, 486 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
new file mode 100644
index 00000000000..1909ce116dd
--- /dev/null
+++ b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
@@ -0,0 +1,486 @@
+//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains switch inst lowering optimizations and utilities for
+// codegen, so that it can be used for both SelectionDAG and GlobalISel.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/SwitchLoweringUtils.h"
+
+using namespace llvm;
+using namespace SwitchCG;
+
+uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
+ unsigned First, unsigned Last) {
+ assert(Last >= First);
+ const APInt &LowCase = Clusters[First].Low->getValue();
+ const APInt &HighCase = Clusters[Last].High->getValue();
+ assert(LowCase.getBitWidth() == HighCase.getBitWidth());
+
+ // FIXME: A range of consecutive cases has 100% density, but only requires one
+ // comparison to lower. We should discriminate against such consecutive ranges
+ // in jump tables.
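+ // For example, clusters spanning case values 5..20 yield a Range of 16,
+ // regardless of how many of those values are actually covered by a case.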
+
+ return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
+}
+
+uint64_t
+SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
+ unsigned First, unsigned Last) {
+ assert(Last >= First);
+ assert(TotalCases[Last] >= TotalCases[First]);
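+ // TotalCases is an inclusive prefix-sum array, so subtracting the sum that
+ // precedes First leaves the case count of Clusters[First..Last]. For example,
+ // with TotalCases = {2, 3, 6}, Clusters[1..2] contain 6 - 2 = 4 cases.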
+ uint64_t NumCases =
+ TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
+ return NumCases;
+}
+
+void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
+ const SwitchInst *SI,
+ MachineBasicBlock *DefaultMBB) {
+#ifndef NDEBUG
+ // Clusters must be non-empty, sorted, and only contain Range clusters.
+ assert(!Clusters.empty());
+ for (CaseCluster &C : Clusters)
+ assert(C.Kind == CC_Range);
+ for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
+ assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
+#endif
+
+ if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
+ return;
+
+ const int64_t N = Clusters.size();
+ const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
+ const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
+
+ if (N < 2 || N < MinJumpTableEntries)
+ return;
+
+ // TotalCases[i]: Total number of cases in Clusters[0..i].
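+ // For example, clusters {0-1, 5, 7-9} produce TotalCases = {2, 3, 6}.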
+ SmallVector<unsigned, 8> TotalCases(N);
+ for (unsigned i = 0; i < N; ++i) {
+ const APInt &Hi = Clusters[i].High->getValue();
+ const APInt &Lo = Clusters[i].Low->getValue();
+ TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
+ if (i != 0)
+ TotalCases[i] += TotalCases[i - 1];
+ }
+
+ // Cheap case: the whole range may be suitable for jump table.
+ uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
+ uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
+ assert(NumCases < UINT64_MAX / 100);
+ assert(Range >= NumCases);
+ if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+ CaseCluster JTCluster;
+ if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
+ Clusters[0] = JTCluster;
+ Clusters.resize(1);
+ return;
+ }
+ }
+
+ // The algorithm below is not suitable for -O0.
+ if (TM->getOptLevel() == CodeGenOpt::None)
+ return;
+
+ // Split Clusters into minimum number of dense partitions. The algorithm uses
+ // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
+ // for the Case Statement'" (1994), but builds the MinPartitions array in
+ // reverse order to make it easier to reconstruct the partitions in ascending
+ // order. In the choice between two optimal partitionings, it picks the one
+ // which yields more jump tables.
+
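+ // For example, given clusters at {0, 1, 2, 3, 1000} and assuming the common
+ // minimum jump table size of four entries, the optimal partitioning is
+ // {0..3} as one jump table plus {1000} on its own, so MinPartitions[0]
+ // ends up as 2.
+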
+ // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
+ SmallVector<unsigned, 8> MinPartitions(N);
+ // LastElement[i] is the last element of the partition starting at i.
+ SmallVector<unsigned, 8> LastElement(N);
+ // PartitionsScore[i] is used to break ties when choosing between two
+ // partitionings resulting in the same number of partitions.
+ SmallVector<unsigned, 8> PartitionsScore(N);
+ // For PartitionsScore, a small number of comparisons is considered as good as
+ // a jump table and a single comparison is considered better than a jump
+ // table.
+ enum PartitionScores : unsigned {
+ NoTable = 0,
+ Table = 1,
+ FewCases = 1,
+ SingleCase = 2
+ };
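+ // Note that Table == FewCases encodes that a handful of comparisons is as
+ // good as a jump table, while SingleCase > Table encodes that a single
+ // comparison is better than one; the tie-break below prefers partitionings
+ // with a higher total score.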
+
+ // Base case: There is only one way to partition Clusters[N-1].
+ MinPartitions[N - 1] = 1;
+ LastElement[N - 1] = N - 1;
+ PartitionsScore[N - 1] = PartitionScores::SingleCase;
+
+ // Note: loop indexes are signed to avoid underflow.
+ for (int64_t i = N - 2; i >= 0; i--) {
+ // Find optimal partitioning of Clusters[i..N-1].
+ // Baseline: Put Clusters[i] into a partition on its own.
+ MinPartitions[i] = MinPartitions[i + 1] + 1;
+ LastElement[i] = i;
+ PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
+
+ // Search for a solution that results in fewer partitions.
+ for (int64_t j = N - 1; j > i; j--) {
+ // Try building a partition from Clusters[i..j].
+ uint64_t Range = getJumpTableRange(Clusters, i, j);
+ uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
+ assert(NumCases < UINT64_MAX / 100);
+ assert(Range >= NumCases);
+ if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+ unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+ unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
+ int64_t NumEntries = j - i + 1;
+
+ if (NumEntries == 1)
+ Score += PartitionScores::SingleCase;
+ else if (NumEntries <= SmallNumberOfEntries)
+ Score += PartitionScores::FewCases;
+ else if (NumEntries >= MinJumpTableEntries)
+ Score += PartitionScores::Table;
+
+ // If this leads to fewer partitions, or to the same number of
+ // partitions with a better score, it is a better partitioning.
+ if (NumPartitions < MinPartitions[i] ||
+ (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
+ MinPartitions[i] = NumPartitions;
+ LastElement[i] = j;
+ PartitionsScore[i] = Score;
+ }
+ }
+ }
+ }
+
+ // Iterate over the partitions, replacing some with jump tables in-place.
+ unsigned DstIndex = 0;
+ for (unsigned First = 0, Last; First < N; First = Last + 1) {
+ Last = LastElement[First];
+ assert(Last >= First);
+ assert(DstIndex <= First);
+ unsigned NumClusters = Last - First + 1;
+
+ CaseCluster JTCluster;
+ if (NumClusters >= MinJumpTableEntries &&
+ buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
+ Clusters[DstIndex++] = JTCluster;
+ } else {
+ for (unsigned I = First; I <= Last; ++I)
+ std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
+ }
+ }
+ Clusters.resize(DstIndex);
+}
+
+bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
+ unsigned First, unsigned Last,
+ const SwitchInst *SI,
+ MachineBasicBlock *DefaultMBB,
+ CaseCluster &JTCluster) {
+ assert(First <= Last);
+
+ auto Prob = BranchProbability::getZero();
+ unsigned NumCmps = 0;
+ std::vector<MachineBasicBlock*> Table;
+ DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
+
+ // Initialize probabilities in JTProbs.
+ for (unsigned I = First; I <= Last; ++I)
+ JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
+
+ for (unsigned I = First; I <= Last; ++I) {
+ assert(Clusters[I].Kind == CC_Range);
+ Prob += Clusters[I].Prob;
+ const APInt &Low = Clusters[I].Low->getValue();
+ const APInt &High = Clusters[I].High->getValue();
+ NumCmps += (Low == High) ? 1 : 2;
+ if (I != First) {
+ // Fill the gap between this and the previous cluster.
+ const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
+ assert(PreviousHigh.slt(Low));
+ uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
+ for (uint64_t J = 0; J < Gap; J++)
+ Table.push_back(DefaultMBB);
+ }
+ uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
+ for (uint64_t J = 0; J < ClusterSize; ++J)
+ Table.push_back(Clusters[I].MBB);
+ JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
+ }
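+ // For the example clusters {1} -> A and {3-4} -> B with default D, the loop
+ // above builds Table = {A, D, B, B}, filling the gap between the clusters
+ // with the default destination.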
+
+ unsigned NumDests = JTProbs.size();
+ if (TLI->isSuitableForBitTests(NumDests, NumCmps,
+ Clusters[First].Low->getValue(),
+ Clusters[Last].High->getValue(), *DL)) {
+ // Clusters[First..Last] should be lowered as bit tests instead.
+ return false;
+ }
+
+ // Create the MBB that will load from and jump through the table.
+ // Note: We create it here, but it's not inserted into the function yet.
+ MachineFunction *CurMF = FuncInfo.MF;
+ MachineBasicBlock *JumpTableMBB =
+ CurMF->CreateMachineBasicBlock(SI->getParent());
+
+ // Add successors. Note: use table order for determinism.
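+ // (Iterating JTProbs directly would visit successors in DenseMap order,
+ // which is not deterministic.)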
+ SmallPtrSet<MachineBasicBlock *, 8> Done;
+ for (MachineBasicBlock *Succ : Table) {
+ if (Done.count(Succ))
+ continue;
+ addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
+ Done.insert(Succ);
+ }
+ JumpTableMBB->normalizeSuccProbs();
+
+ unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
+ ->createJumpTableIndex(Table);
+
+ // Set up the jump table info.
+ JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
+ JumpTableHeader JTH(Clusters[First].Low->getValue(),
+ Clusters[Last].High->getValue(), SI->getCondition(),
+ nullptr, false);
+ JTCases.emplace_back(std::move(JTH), std::move(JT));
+
+ JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
+ JTCases.size() - 1, Prob);
+ return true;
+}
+
+void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
+ const SwitchInst *SI) {
+ // Partition Clusters into as few subsets as possible, where each subset has a
+ // range that fits in a machine word and has <= 3 unique destinations.
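+ // Each subset is lowered to one range check plus one mask test per distinct
+ // destination, so keeping the number of destinations small keeps the
+ // lowering cheap.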
+
+#ifndef NDEBUG
+ // Clusters must be sorted and contain Range or JumpTable clusters.
+ assert(!Clusters.empty());
+ assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
+ for (const CaseCluster &C : Clusters)
+ assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
+ for (unsigned i = 1; i < Clusters.size(); ++i)
+ assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
+#endif
+
+ // The algorithm below is not suitable for -O0.
+ if (TM->getOptLevel() == CodeGenOpt::None)
+ return;
+
+ // If the target does not have a legal shift left, do not emit bit tests at all.
+ EVT PTy = TLI->getPointerTy(*DL);
+ if (!TLI->isOperationLegal(ISD::SHL, PTy))
+ return;
+
+ int BitWidth = PTy.getSizeInBits();
+ const int64_t N = Clusters.size();
+
+ // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
+ SmallVector<unsigned, 8> MinPartitions(N);
+ // LastElement[i] is the last element of the partition starting at i.
+ SmallVector<unsigned, 8> LastElement(N);
+
+ // FIXME: This might not be the best algorithm for finding bit test clusters.
+
+ // Base case: There is only one way to partition Clusters[N-1].
+ MinPartitions[N - 1] = 1;
+ LastElement[N - 1] = N - 1;
+
+ // Note: loop indexes are signed to avoid underflow.
+ for (int64_t i = N - 2; i >= 0; --i) {
+ // Find optimal partitioning of Clusters[i..N-1].
+ // Baseline: Put Clusters[i] into a partition on its own.
+ MinPartitions[i] = MinPartitions[i + 1] + 1;
+ LastElement[i] = i;
+
+ // Search for a solution that results in fewer partitions.
+ // Note: the search is limited by BitWidth, reducing time complexity.
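+ // Since clusters are disjoint and sorted, a partition of more than BitWidth
+ // clusters spans more than BitWidth values and can never fit in a word.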
+ for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
+ // Try building a partition from Clusters[i..j].
+
+ // Check the range.
+ if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
+ Clusters[j].High->getValue(), *DL))
+ continue;
+
+ // Check the number of destinations and cluster types.
+ // FIXME: This works, but doesn't seem very efficient.
+ bool RangesOnly = true;
+ BitVector Dests(FuncInfo.MF->getNumBlockIDs());
+ for (int64_t k = i; k <= j; k++) {
+ if (Clusters[k].Kind != CC_Range) {
+ RangesOnly = false;
+ break;
+ }
+ Dests.set(Clusters[k].MBB->getNumber());
+ }
+ if (!RangesOnly || Dests.count() > 3)
+ break;
+
+ // Check if it's a better partition.
+ unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+ if (NumPartitions < MinPartitions[i]) {
+ // Found a better partition.
+ MinPartitions[i] = NumPartitions;
+ LastElement[i] = j;
+ }
+ }
+ }
+
+ // Iterate over the partitions, replacing with bit-test clusters in-place.
+ unsigned DstIndex = 0;
+ for (unsigned First = 0, Last; First < N; First = Last + 1) {
+ Last = LastElement[First];
+ assert(First <= Last);
+ assert(DstIndex <= First);
+
+ CaseCluster BitTestCluster;
+ if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
+ Clusters[DstIndex++] = BitTestCluster;
+ } else {
+ size_t NumClusters = Last - First + 1;
+ std::memmove(&Clusters[DstIndex], &Clusters[First],
+ sizeof(Clusters[0]) * NumClusters);
+ DstIndex += NumClusters;
+ }
+ }
+ Clusters.resize(DstIndex);
+}
+
+bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
+ unsigned First, unsigned Last,
+ const SwitchInst *SI,
+ CaseCluster &BTCluster) {
+ assert(First <= Last);
+ if (First == Last)
+ return false;
+
+ BitVector Dests(FuncInfo.MF->getNumBlockIDs());
+ unsigned NumCmps = 0;
+ for (int64_t I = First; I <= Last; ++I) {
+ assert(Clusters[I].Kind == CC_Range);
+ Dests.set(Clusters[I].MBB->getNumber());
+ NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
+ }
+ unsigned NumDests = Dests.count();
+
+ APInt Low = Clusters[First].Low->getValue();
+ APInt High = Clusters[Last].High->getValue();
+ assert(Low.slt(High));
+
+ if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
+ return false;
+
+ APInt LowBound;
+ APInt CmpRange;
+
+ const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
+ assert(TLI->rangeFitsInWord(Low, High, *DL) &&
+ "Case range must fit in bit mask!");
+
+ // Check if the clusters cover a contiguous range such that no value in the
+ // range will jump to the default statement.
+ bool ContiguousRange = true;
+ for (int64_t I = First + 1; I <= Last; ++I) {
+ if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
+ ContiguousRange = false;
+ break;
+ }
+ }
+
+ if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
+ // Optimize the case where all the case values fit in a word without having
+ // to subtract minValue: the subtraction can be elided entirely.
+ LowBound = APInt::getNullValue(Low.getBitWidth());
+ CmpRange = High;
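+ // Values below Low now also pass the range check, and they must branch to
+ // the default destination, so the range can no longer be treated as
+ // contiguous.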
+ ContiguousRange = false;
+ } else {
+ LowBound = Low;
+ CmpRange = High - Low;
+ }
+
+ CaseBitsVector CBV;
+ auto TotalProb = BranchProbability::getZero();
+ for (unsigned i = First; i <= Last; ++i) {
+ // Find the CaseBits for this destination.
+ unsigned j;
+ for (j = 0; j < CBV.size(); ++j)
+ if (CBV[j].BB == Clusters[i].MBB)
+ break;
+ if (j == CBV.size())
+ CBV.push_back(
+ CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
+ CaseBits *CB = &CBV[j];
+
+ // Update Mask, Bits and ExtraProb.
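+ // For example, a cluster covering 2..4 with LowBound 0 gives Lo = 2 and
+ // Hi = 4, setting mask bits 0b11100.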
+ uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
+ uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
+ assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
+ CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
+ CB->Bits += Hi - Lo + 1;
+ CB->ExtraProb += Clusters[i].Prob;
+ TotalProb += Clusters[i].Prob;
+ }
+
+ BitTestInfo BTI;
+ llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
+ // Sort by probability first, number of bits second, bit mask third.
+ if (a.ExtraProb != b.ExtraProb)
+ return a.ExtraProb > b.ExtraProb;
+ if (a.Bits != b.Bits)
+ return a.Bits > b.Bits;
+ return a.Mask < b.Mask;
+ });
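+ // The bit tests are emitted in this order, so the most probable
+ // destinations are tested first.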
+
+ for (auto &CB : CBV) {
+ MachineBasicBlock *BitTestBB =
+ FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
+ BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
+ }
+ BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
+ SI->getCondition(), -1U, MVT::Other, false,
+ ContiguousRange, nullptr, nullptr, std::move(BTI),
+ TotalProb);
+
+ BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
+ BitTestCases.size() - 1, TotalProb);
+ return true;
+}
+
+void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
+#ifndef NDEBUG
+ for (const CaseCluster &CC : Clusters)
+ assert(CC.Low == CC.High && "Input clusters must be single-case");
+#endif
+
+ llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
+ return a.Low->getValue().slt(b.Low->getValue());
+ });
+
+ // Merge adjacent clusters with the same destination.
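+ // For example, cases 1 -> A, 2 -> A, 3 -> B become clusters {1-2} -> A and
+ // {3} -> B.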
+ const unsigned N = Clusters.size();
+ unsigned DstIndex = 0;
+ for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
+ CaseCluster &CC = Clusters[SrcIndex];
+ const ConstantInt *CaseVal = CC.Low;
+ MachineBasicBlock *Succ = CC.MBB;
+
+ if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
+ (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
+ // If this case has the same successor and is a neighbour, merge it into
+ // the previous cluster.
+ Clusters[DstIndex - 1].High = CaseVal;
+ Clusters[DstIndex - 1].Prob += CC.Prob;
+ } else {
+ std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
+ sizeof(Clusters[SrcIndex]));
+ }
+ }
+ Clusters.resize(DstIndex);
+}