author | Hans Wennborg <hans@hanshq.net> | 2015-04-22 23:14:56 +0000 |
---|---|---|
committer | Hans Wennborg <hans@hanshq.net> | 2015-04-22 23:14:56 +0000 |
commit | 15823d49b61fc7ddafc2cb10b12f4498d53fb216 (patch) | |
tree | 16b037a114c50361f80741b97a09f48dff8cd092 /llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | |
parent | 0405d68bb472f208ef660ad69682fe1b54a58804 (diff) | |
download | bcm5719-llvm-15823d49b61fc7ddafc2cb10b12f4498d53fb216.tar.gz bcm5719-llvm-15823d49b61fc7ddafc2cb10b12f4498d53fb216.zip |
Switch lowering: extract jump tables and bit tests before building binary tree (PR22262)
This is a re-commit of r235101; it also fixes the following problems with the previous patch:
- Switches with only a default case and a non-fallthrough default destination were handled incorrectly
- The previous patch tickled a bug in PowerPC Early-Return Creation which is fixed here.
> This is a major rewrite of the SelectionDAG switch lowering. The previous code
> would lower switches as a binary tree, discovering clusters of cases
> suitable for lowering by jump tables or bit tests as it went along. To increase
> the likelihood of finding jump tables, the binary tree pivot was selected to
> maximize case density on both sides of the pivot.
>
> Because the pivot was not always selected in the middle, the binary trees were
> not always balanced, leading to performance problems in the generated code.
>
> This patch rewrites the lowering to search for clusters of cases
> suitable for jump tables or bit tests first, and then builds the binary
> tree around those clusters. This way, the binary tree will always be balanced.
>
> This has the added benefit of decoupling the different aspects of the lowering:
> tree building and the discovery of jump tables or bit tests are now easier to
> tweak separately.
>
> For example, this will enable us to balance the tree based on profile info
> in the future.
>
> The algorithm for finding jump tables is quadratic, whereas the previous algorithm
> was O(n log n) for common cases, and quadratic only in the worst case. This
> doesn't seem to be a major problem in practice, e.g. compiling a file consisting
> of a 10k-case switch was only 30% slower, and such large switches should be rare
> in practice. Compiling e.g. gcc.c showed no compile-time difference. If this
> does turn out to be a problem, we could limit the search space of the algorithm.
>
> This commit also disables all optimizations during switch lowering in -O0.
>
> Differential Revision: http://reviews.llvm.org/D8649
llvm-svn: 235560
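
The rewrite described in the message boils down to a worklist scheme: once clusters have been formed, each work item is split at its midpoint until it is small enough to emit as a short chain of compares. The sketch below illustrates only that shape, using hypothetical simplified types (Cluster, WorkItem) and commented-out emitter calls in place of the real SelectionDAGBuilder machinery; it is not the patch's code.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for CaseCluster / SwitchWorkListItem.
struct Cluster { int64_t Low, High; };   // one contiguous range of case values
struct WorkItem { size_t First, Last; }; // inclusive range of clusters to lower

// Lower a sorted cluster list by repeatedly splitting at the midpoint.
// Because the pivot is always the middle cluster, the tree stays balanced.
void lowerClusters(const std::vector<Cluster> &Clusters) {
  if (Clusters.empty())
    return;
  std::vector<WorkItem> WorkList{{0, Clusters.size() - 1}};
  while (!WorkList.empty()) {
    WorkItem W = WorkList.back();
    WorkList.pop_back();
    size_t N = W.Last - W.First + 1;
    if (N <= 3) {
      // emitCompareChain(Clusters, W.First, W.Last);  // hypothetical emitter
      continue;
    }
    size_t Pivot = W.First + N / 2;
    // emitLessThanBranch(Clusters[Pivot].Low);        // hypothetical: Cond < pivot?
    WorkList.push_back({W.First, Pivot - 1});          // left subtree
    WorkList.push_back({Pivot, W.Last});               // right subtree
  }
}

Jump-table and bit-test clusters occupy a single slot in the cluster list, so they fall out of the same midpoint split and keep the tree balanced no matter how many cases they absorb.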
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 1443 |
1 file changed, 780 insertions, 663 deletions
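
The quadratic search for jump tables mentioned in the message partitions the sorted clusters into a minimum number of dense ranges, in the spirit of the Kannan & Proebsting correction cited in the new findJumpTables(). Below is a self-contained sketch of that dynamic program under simplifying assumptions (illustrative Cluster type, density threshold passed in as a percentage, no tie-breaking on the number of resulting tables, no bit-test fallback); it is not the LLVM implementation. The doubly nested (I, J) loop is where the quadratic cost comes from, while a prefix sum keeps each density test O(1).

#include <cstdint>
#include <utility>
#include <vector>

struct Cluster { int64_t Low, High; }; // illustrative: one contiguous case range

// Split sorted, non-overlapping Clusters into the minimum number of dense
// partitions; a partition is dense if its cases cover at least
// MinDensityPercent of its value range. Long-enough partitions can then be
// emitted as jump tables.
std::vector<std::pair<unsigned, unsigned>>
partitionForJumpTables(const std::vector<Cluster> &C, unsigned MinDensityPercent) {
  const int64_t N = static_cast<int64_t>(C.size());
  if (N == 0)
    return {};

  // TotalCases[I] = number of case values covered by Clusters[0..I].
  std::vector<uint64_t> TotalCases(N);
  for (int64_t I = 0; I < N; ++I)
    TotalCases[I] = (I ? TotalCases[I - 1] : 0) + uint64_t(C[I].High - C[I].Low) + 1;

  auto IsDense = [&](int64_t First, int64_t Last) {
    uint64_t Range = uint64_t(C[Last].High - C[First].Low) + 1;
    uint64_t Cases = TotalCases[Last] - (First ? TotalCases[First - 1] : 0);
    return Cases * 100 >= Range * MinDensityPercent;
  };

  // MinPartitions[I] = minimum number of partitions of Clusters[I..N-1];
  // LastElement[I]   = last cluster of the partition that starts at I.
  std::vector<int64_t> MinPartitions(N), LastElement(N);
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  for (int64_t I = N - 2; I >= 0; --I) {
    MinPartitions[I] = MinPartitions[I + 1] + 1; // baseline: Clusters[I] alone
    LastElement[I] = I;
    for (int64_t J = N - 1; J > I; --J) {
      if (!IsDense(I, J))
        continue;
      int64_t NumPartitions = 1 + (J == N - 1 ? 0 : MinPartitions[J + 1]);
      if (NumPartitions < MinPartitions[I]) {
        MinPartitions[I] = NumPartitions;
        LastElement[I] = J;
      }
    }
  }

  // Reconstruct the chosen partitions in ascending order.
  std::vector<std::pair<unsigned, unsigned>> Parts;
  for (int64_t First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    Parts.push_back({unsigned(First), unsigned(Last)});
  }
  return Parts;
}

A caller in the spirit of the patch would emit a jump table for each returned partition whose length reaches the target's minimum jump-table size and leave the remaining clusters for the balanced tree (or for bit tests).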
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 2c813138f84..bc7d8fdb7ac 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -1928,7 +1928,7 @@ void SelectionDAGBuilder::visitBitTestHeader(BitTestBlock &B, // Avoid emitting unnecessary branches to the next block. if (MBB != NextBlock(SwitchBB)) - BrRange = DAG.getNode(ISD::BR, getCurSDLoc(), MVT::Other, CopyTo, + BrRange = DAG.getNode(ISD::BR, getCurSDLoc(), MVT::Other, BrRange, DAG.getBasicBlock(MBB)); DAG.setRoot(BrRange); @@ -2101,592 +2101,37 @@ SelectionDAGBuilder::visitLandingPadClauseBB(GlobalValue *ClauseGV, return VReg; } -/// handleSmallSwitchCaseRange - Emit a series of specific tests (suitable for -/// small case ranges). -bool SelectionDAGBuilder::handleSmallSwitchRange(CaseRec& CR, - CaseRecVector& WorkList, - const Value* SV, - MachineBasicBlock *Default, - MachineBasicBlock *SwitchBB) { - // Size is the number of Cases represented by this range. - size_t Size = CR.Range.second - CR.Range.first; - if (Size > 3) - return false; - - // Get the MachineFunction which holds the current MBB. This is used when - // inserting any additional MBBs necessary to represent the switch. - MachineFunction *CurMF = FuncInfo.MF; - - // Figure out which block is immediately after the current one. - MachineBasicBlock *NextMBB = nullptr; - MachineFunction::iterator BBI = CR.CaseBB; - if (++BBI != FuncInfo.MF->end()) - NextMBB = BBI; - - BranchProbabilityInfo *BPI = FuncInfo.BPI; - // If any two of the cases has the same destination, and if one value - // is the same as the other, but has one bit unset that the other has set, - // use bit manipulation to do two compares at once. For example: - // "if (X == 6 || X == 4)" -> "if ((X|2) == 6)" - // TODO: This could be extended to merge any 2 cases in switches with 3 cases. - // TODO: Handle cases where CR.CaseBB != SwitchBB. - if (Size == 2 && CR.CaseBB == SwitchBB) { - Case &Small = *CR.Range.first; - Case &Big = *(CR.Range.second-1); - - if (Small.Low == Small.High && Big.Low == Big.High && Small.BB == Big.BB) { - const APInt& SmallValue = Small.Low->getValue(); - const APInt& BigValue = Big.Low->getValue(); - - // Check that there is only one bit different. - if (BigValue.countPopulation() == SmallValue.countPopulation() + 1 && - (SmallValue | BigValue) == BigValue) { - // Isolate the common bit. - APInt CommonBit = BigValue & ~SmallValue; - assert((SmallValue | CommonBit) == BigValue && - CommonBit.countPopulation() == 1 && "Not a common bit?"); - - SDValue CondLHS = getValue(SV); - EVT VT = CondLHS.getValueType(); - SDLoc DL = getCurSDLoc(); - - SDValue Or = DAG.getNode(ISD::OR, DL, VT, CondLHS, - DAG.getConstant(CommonBit, VT)); - SDValue Cond = DAG.getSetCC(DL, MVT::i1, - Or, DAG.getConstant(BigValue, VT), - ISD::SETEQ); - - // Update successor info. - // Both Small and Big will jump to Small.BB, so we sum up the weights. - addSuccessorWithWeight(SwitchBB, Small.BB, - Small.ExtraWeight + Big.ExtraWeight); - addSuccessorWithWeight(SwitchBB, Default, - // The default destination is the first successor in IR. - BPI ? BPI->getEdgeWeight(SwitchBB->getBasicBlock(), (unsigned)0) : 0); - - // Insert the true branch. - SDValue BrCond = DAG.getNode(ISD::BRCOND, DL, MVT::Other, - getControlRoot(), Cond, - DAG.getBasicBlock(Small.BB)); - - // Insert the false branch. 
- BrCond = DAG.getNode(ISD::BR, DL, MVT::Other, BrCond, - DAG.getBasicBlock(Default)); - - DAG.setRoot(BrCond); - return true; - } - } - } - - // Order cases by weight so the most likely case will be checked first. - uint32_t UnhandledWeights = 0; - if (BPI) { - for (CaseItr I = CR.Range.first, IE = CR.Range.second; I != IE; ++I) { - uint32_t IWeight = I->ExtraWeight; - UnhandledWeights += IWeight; - for (CaseItr J = CR.Range.first; J < I; ++J) { - uint32_t JWeight = J->ExtraWeight; - if (IWeight > JWeight) - std::swap(*I, *J); - } - } - } - // Rearrange the case blocks so that the last one falls through if possible. - Case &BackCase = *(CR.Range.second-1); - if (Size > 1 && NextMBB && Default != NextMBB && BackCase.BB != NextMBB) { - // The last case block won't fall through into 'NextMBB' if we emit the - // branches in this order. See if rearranging a case value would help. - // We start at the bottom as it's the case with the least weight. - for (Case *I = &*(CR.Range.second-2), *E = &*CR.Range.first-1; I != E; --I) - if (I->BB == NextMBB) { - std::swap(*I, BackCase); - break; - } - } - - // Create a CaseBlock record representing a conditional branch to - // the Case's target mbb if the value being switched on SV is equal - // to C. - MachineBasicBlock *CurBlock = CR.CaseBB; - for (CaseItr I = CR.Range.first, E = CR.Range.second; I != E; ++I) { - MachineBasicBlock *FallThrough; - if (I != E-1) { - FallThrough = CurMF->CreateMachineBasicBlock(CurBlock->getBasicBlock()); - CurMF->insert(BBI, FallThrough); - - // Put SV in a virtual register to make it available from the new blocks. - ExportFromCurrentBlock(SV); - } else { - // If the last case doesn't match, go to the default block. - FallThrough = Default; - } - - const Value *RHS, *LHS, *MHS; - ISD::CondCode CC; - if (I->High == I->Low) { - // This is just small small case range :) containing exactly 1 case - CC = ISD::SETEQ; - LHS = SV; RHS = I->High; MHS = nullptr; - } else { - CC = ISD::SETLE; - LHS = I->Low; MHS = SV; RHS = I->High; - } - - // The false weight should be sum of all un-handled cases. - UnhandledWeights -= I->ExtraWeight; - CaseBlock CB(CC, LHS, RHS, MHS, /* truebb */ I->BB, /* falsebb */ FallThrough, - /* me */ CurBlock, - /* trueweight */ I->ExtraWeight, - /* falseweight */ UnhandledWeights); - - // If emitting the first comparison, just call visitSwitchCase to emit the - // code into the current block. Otherwise, push the CaseBlock onto the - // vector to be later processed by SDISel, and insert the node's MBB - // before the next MBB. 
- if (CurBlock == SwitchBB) - visitSwitchCase(CB, SwitchBB); - else - SwitchCases.push_back(CB); - - CurBlock = FallThrough; - } - - return true; -} - -static inline bool areJTsAllowed(const TargetLowering &TLI) { - return TLI.isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) || - TLI.isOperationLegalOrCustom(ISD::BRIND, MVT::Other); -} - -static APInt ComputeRange(const APInt &First, const APInt &Last) { - uint32_t BitWidth = std::max(Last.getBitWidth(), First.getBitWidth()) + 1; - APInt LastExt = Last.sext(BitWidth), FirstExt = First.sext(BitWidth); - return (LastExt - FirstExt + 1ULL); -} - -/// handleJTSwitchCase - Emit jumptable for current switch case range -bool SelectionDAGBuilder::handleJTSwitchCase(CaseRec &CR, - CaseRecVector &WorkList, - const Value *SV, - MachineBasicBlock *Default, - MachineBasicBlock *SwitchBB) { - Case& FrontCase = *CR.Range.first; - Case& BackCase = *(CR.Range.second-1); - - const APInt &First = FrontCase.Low->getValue(); - const APInt &Last = BackCase.High->getValue(); - - APInt TSize(First.getBitWidth(), 0); - for (CaseItr I = CR.Range.first, E = CR.Range.second; I != E; ++I) - TSize += I->size(); - - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!areJTsAllowed(TLI) || TSize.ult(TLI.getMinimumJumpTableEntries())) - return false; - - APInt Range = ComputeRange(First, Last); - // The density is TSize / Range. Require at least 40%. - // It should not be possible for IntTSize to saturate for sane code, but make - // sure we handle Range saturation correctly. - uint64_t IntRange = Range.getLimitedValue(UINT64_MAX/10); - uint64_t IntTSize = TSize.getLimitedValue(UINT64_MAX/10); - if (IntTSize * 10 < IntRange * 4) - return false; - - DEBUG(dbgs() << "Lowering jump table\n" - << "First entry: " << First << ". Last entry: " << Last << '\n' - << "Range: " << Range << ". Size: " << TSize << ".\n\n"); - - // Get the MachineFunction which holds the current MBB. This is used when - // inserting any additional MBBs necessary to represent the switch. - MachineFunction *CurMF = FuncInfo.MF; - - // Figure out which block is immediately after the current one. - MachineFunction::iterator BBI = CR.CaseBB; - ++BBI; - - const BasicBlock *LLVMBB = CR.CaseBB->getBasicBlock(); - - // Create a new basic block to hold the code for loading the address - // of the jump table, and jumping to it. Update successor information; - // we will either branch to the default case for the switch, or the jump - // table. - MachineBasicBlock *JumpTableBB = CurMF->CreateMachineBasicBlock(LLVMBB); - CurMF->insert(BBI, JumpTableBB); - - addSuccessorWithWeight(CR.CaseBB, Default); - addSuccessorWithWeight(CR.CaseBB, JumpTableBB); - - // Build a vector of destination BBs, corresponding to each target - // of the jump table. If the value of the jump table slot corresponds to - // a case statement, push the case's BB onto the vector, otherwise, push - // the default BB. - std::vector<MachineBasicBlock*> DestBBs; - APInt TEI = First; - for (CaseItr I = CR.Range.first, E = CR.Range.second; I != E; ++TEI) { - const APInt &Low = I->Low->getValue(); - const APInt &High = I->High->getValue(); - - if (Low.sle(TEI) && TEI.sle(High)) { - DestBBs.push_back(I->BB); - if (TEI==High) - ++I; - } else { - DestBBs.push_back(Default); - } - } - - // Calculate weight for each unique destination in CR. 
- DenseMap<MachineBasicBlock*, uint32_t> DestWeights; - if (FuncInfo.BPI) { - for (CaseItr I = CR.Range.first, E = CR.Range.second; I != E; ++I) - DestWeights[I->BB] += I->ExtraWeight; - } - - // Update successor info. Add one edge to each unique successor. - BitVector SuccsHandled(CR.CaseBB->getParent()->getNumBlockIDs()); - for (MachineBasicBlock *DestBB : DestBBs) { - if (!SuccsHandled[DestBB->getNumber()]) { - SuccsHandled[DestBB->getNumber()] = true; - auto I = DestWeights.find(DestBB); - addSuccessorWithWeight(JumpTableBB, DestBB, - I != DestWeights.end() ? I->second : 0); - } - } - - // Create a jump table index for this jump table. - unsigned JTEncoding = TLI.getJumpTableEncoding(); - unsigned JTI = CurMF->getOrCreateJumpTableInfo(JTEncoding) - ->createJumpTableIndex(DestBBs); - - // Set the jump table information so that we can codegen it as a second - // MachineBasicBlock - JumpTable JT(-1U, JTI, JumpTableBB, Default); - JumpTableHeader JTH(First, Last, SV, CR.CaseBB, (CR.CaseBB == SwitchBB)); - if (CR.CaseBB == SwitchBB) - visitJumpTableHeader(JT, JTH, SwitchBB); - - JTCases.push_back(JumpTableBlock(JTH, JT)); - return true; -} - -/// handleBTSplitSwitchCase - emit comparison and split binary search tree into -/// 2 subtrees. -bool SelectionDAGBuilder::handleBTSplitSwitchCase(CaseRec& CR, - CaseRecVector& WorkList, - const Value* SV, - MachineBasicBlock* SwitchBB) { - Case& FrontCase = *CR.Range.first; - Case& BackCase = *(CR.Range.second-1); - - // Size is the number of Cases represented by this range. - unsigned Size = CR.Range.second - CR.Range.first; - - const APInt &First = FrontCase.Low->getValue(); - const APInt &Last = BackCase.High->getValue(); - double FMetric = 0; - CaseItr Pivot = CR.Range.first + Size/2; - - // Select optimal pivot, maximizing sum density of LHS and RHS. This will - // (heuristically) allow us to emit JumpTable's later. - APInt TSize(First.getBitWidth(), 0); - for (CaseItr I = CR.Range.first, E = CR.Range.second; - I!=E; ++I) - TSize += I->size(); - - APInt LSize = FrontCase.size(); - APInt RSize = TSize-LSize; - DEBUG(dbgs() << "Selecting best pivot: \n" - << "First: " << First << ", Last: " << Last <<'\n' - << "LSize: " << LSize << ", RSize: " << RSize << '\n'); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - for (CaseItr I = CR.Range.first, J=I+1, E = CR.Range.second; - J!=E; ++I, ++J) { - const APInt &LEnd = I->High->getValue(); - const APInt &RBegin = J->Low->getValue(); - APInt Range = ComputeRange(LEnd, RBegin); - assert((Range - 2ULL).isNonNegative() && - "Invalid case distance"); - // Use volatile double here to avoid excess precision issues on some hosts, - // e.g. that use 80-bit X87 registers. - // Only consider the density of sub-ranges that actually have sufficient - // entries to be lowered as a jump table. - volatile double LDensity = - LSize.ult(TLI.getMinimumJumpTableEntries()) - ? 0.0 - : LSize.roundToDouble() / (LEnd - First + 1ULL).roundToDouble(); - volatile double RDensity = - RSize.ult(TLI.getMinimumJumpTableEntries()) - ? 
0.0 - : RSize.roundToDouble() / (Last - RBegin + 1ULL).roundToDouble(); - volatile double Metric = Range.logBase2() * (LDensity + RDensity); - // Should always split in some non-trivial place - DEBUG(dbgs() <<"=>Step\n" - << "LEnd: " << LEnd << ", RBegin: " << RBegin << '\n' - << "LDensity: " << LDensity - << ", RDensity: " << RDensity << '\n' - << "Metric: " << Metric << '\n'); - if (FMetric < Metric) { - Pivot = J; - FMetric = Metric; - DEBUG(dbgs() << "Current metric set to: " << FMetric << '\n'); - } - - LSize += J->size(); - RSize -= J->size(); - } - - if (FMetric == 0 || !areJTsAllowed(TLI)) - Pivot = CR.Range.first + Size/2; - splitSwitchCase(CR, Pivot, WorkList, SV, SwitchBB); - return true; -} - -void SelectionDAGBuilder::splitSwitchCase(CaseRec &CR, CaseItr Pivot, - CaseRecVector &WorkList, - const Value *SV, - MachineBasicBlock *SwitchBB) { - // Get the MachineFunction which holds the current MBB. This is used when - // inserting any additional MBBs necessary to represent the switch. - MachineFunction *CurMF = FuncInfo.MF; - - // Figure out which block is immediately after the current one. - MachineFunction::iterator BBI = CR.CaseBB; - ++BBI; - - const BasicBlock *LLVMBB = CR.CaseBB->getBasicBlock(); - - CaseRange LHSR(CR.Range.first, Pivot); - CaseRange RHSR(Pivot, CR.Range.second); - const ConstantInt *C = Pivot->Low; - MachineBasicBlock *FalseBB = nullptr, *TrueBB = nullptr; - - // We know that we branch to the LHS if the Value being switched on is - // less than the Pivot value, C. We use this to optimize our binary - // tree a bit, by recognizing that if SV is greater than or equal to the - // LHS's Case Value, and that Case Value is exactly one less than the - // Pivot's Value, then we can branch directly to the LHS's Target, - // rather than creating a leaf node for it. - if ((LHSR.second - LHSR.first) == 1 && LHSR.first->High == CR.GE && - C->getValue() == (CR.GE->getValue() + 1LL)) { - TrueBB = LHSR.first->BB; - } else { - TrueBB = CurMF->CreateMachineBasicBlock(LLVMBB); - CurMF->insert(BBI, TrueBB); - WorkList.push_back(CaseRec(TrueBB, C, CR.GE, LHSR)); - - // Put SV in a virtual register to make it available from the new blocks. - ExportFromCurrentBlock(SV); - } - - // Similar to the optimization above, if the Value being switched on is - // known to be less than the Constant CR.LT, and the current Case Value - // is CR.LT - 1, then we can branch directly to the target block for - // the current Case Value, rather than emitting a RHS leaf node for it. - if ((RHSR.second - RHSR.first) == 1 && CR.LT && - RHSR.first->Low->getValue() == (CR.LT->getValue() - 1LL)) { - FalseBB = RHSR.first->BB; - } else { - FalseBB = CurMF->CreateMachineBasicBlock(LLVMBB); - CurMF->insert(BBI, FalseBB); - WorkList.push_back(CaseRec(FalseBB, CR.LT, C, RHSR)); - - // Put SV in a virtual register to make it available from the new blocks. - ExportFromCurrentBlock(SV); - } - - // Create a CaseBlock record representing a conditional branch to - // the LHS node if the value being switched on SV is less than C. - // Otherwise, branch to LHS. - CaseBlock CB(ISD::SETLT, SV, C, nullptr, TrueBB, FalseBB, CR.CaseBB); - - if (CR.CaseBB == SwitchBB) - visitSwitchCase(CB, SwitchBB); - else - SwitchCases.push_back(CB); -} - -/// handleBitTestsSwitchCase - if current case range has few destination and -/// range span less, than machine word bitwidth, encode case range into series -/// of masks and emit bit tests with these masks. 
-bool SelectionDAGBuilder::handleBitTestsSwitchCase(CaseRec& CR, - CaseRecVector& WorkList, - const Value* SV, - MachineBasicBlock* Default, - MachineBasicBlock* SwitchBB) { - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - EVT PTy = TLI.getPointerTy(); - unsigned IntPtrBits = PTy.getSizeInBits(); - - Case& FrontCase = *CR.Range.first; - Case& BackCase = *(CR.Range.second-1); - - // Get the MachineFunction which holds the current MBB. This is used when - // inserting any additional MBBs necessary to represent the switch. - MachineFunction *CurMF = FuncInfo.MF; - - // If target does not have legal shift left, do not emit bit tests at all. - if (!TLI.isOperationLegal(ISD::SHL, PTy)) - return false; - - size_t numCmps = 0; - for (CaseItr I = CR.Range.first, E = CR.Range.second; I != E; ++I) { - // Single case counts one, case range - two. - numCmps += (I->Low == I->High ? 1 : 2); - } - - // Count unique destinations - SmallSet<MachineBasicBlock*, 4> Dests; - for (CaseItr I = CR.Range.first, E = CR.Range.second; I != E; ++I) { - Dests.insert(I->BB); - if (Dests.size() > 3) - // Don't bother the code below, if there are too much unique destinations - return false; - } - DEBUG(dbgs() << "Total number of unique destinations: " - << Dests.size() << '\n' - << "Total number of comparisons: " << numCmps << '\n'); - - // Compute span of values. - const APInt& minValue = FrontCase.Low->getValue(); - const APInt& maxValue = BackCase.High->getValue(); - APInt cmpRange = maxValue - minValue; - - DEBUG(dbgs() << "Compare range: " << cmpRange << '\n' - << "Low bound: " << minValue << '\n' - << "High bound: " << maxValue << '\n'); - - if (cmpRange.uge(IntPtrBits) || - (!(Dests.size() == 1 && numCmps >= 3) && - !(Dests.size() == 2 && numCmps >= 5) && - !(Dests.size() >= 3 && numCmps >= 6))) - return false; - - DEBUG(dbgs() << "Emitting bit tests\n"); - APInt lowBound = APInt::getNullValue(cmpRange.getBitWidth()); - - // Optimize the case where all the case values fit in a - // word without having to subtract minValue. In this case, - // we can optimize away the subtraction. - if (minValue.isNonNegative() && maxValue.slt(IntPtrBits)) { - cmpRange = maxValue; - } else { - lowBound = minValue; - } - - CaseBitsVector CasesBits; - unsigned i, count = 0; - - for (CaseItr I = CR.Range.first, E = CR.Range.second; I!=E; ++I) { - MachineBasicBlock* Dest = I->BB; - for (i = 0; i < count; ++i) - if (Dest == CasesBits[i].BB) - break; - - if (i == count) { - assert((count < 3) && "Too much destinations to test!"); - CasesBits.push_back(CaseBits(0, Dest, 0, 0/*Weight*/)); - count++; - } - - const APInt& lowValue = I->Low->getValue(); - const APInt& highValue = I->High->getValue(); - - uint64_t lo = (lowValue - lowBound).getZExtValue(); - uint64_t hi = (highValue - lowBound).getZExtValue(); - CasesBits[i].ExtraWeight += I->ExtraWeight; - - for (uint64_t j = lo; j <= hi; j++) { - CasesBits[i].Mask |= 1ULL << j; - CasesBits[i].Bits++; - } - - } - std::sort(CasesBits.begin(), CasesBits.end(), CaseBitsCmp()); - - BitTestInfo BTC; - - // Figure out which block is immediately after the current one. 
- MachineFunction::iterator BBI = CR.CaseBB; - ++BBI; - - const BasicBlock *LLVMBB = CR.CaseBB->getBasicBlock(); - - DEBUG(dbgs() << "Cases:\n"); - for (unsigned i = 0, e = CasesBits.size(); i!=e; ++i) { - DEBUG(dbgs() << "Mask: " << CasesBits[i].Mask - << ", Bits: " << CasesBits[i].Bits - << ", BB: " << CasesBits[i].BB << '\n'); - - MachineBasicBlock *CaseBB = CurMF->CreateMachineBasicBlock(LLVMBB); - CurMF->insert(BBI, CaseBB); - BTC.push_back(BitTestCase(CasesBits[i].Mask, - CaseBB, - CasesBits[i].BB, CasesBits[i].ExtraWeight)); - - // Put SV in a virtual register to make it available from the new blocks. - ExportFromCurrentBlock(SV); - } - - BitTestBlock BTB(lowBound, cmpRange, SV, - -1U, MVT::Other, (CR.CaseBB == SwitchBB), - CR.CaseBB, Default, std::move(BTC)); - - if (CR.CaseBB == SwitchBB) - visitBitTestHeader(BTB, SwitchBB); - - BitTestCases.push_back(std::move(BTB)); - - return true; -} - -void SelectionDAGBuilder::Clusterify(CaseVector &Cases, const SwitchInst *SI) { - BranchProbabilityInfo *BPI = FuncInfo.BPI; +void SelectionDAGBuilder::sortAndRangeify(CaseClusterVector &Clusters) { +#ifndef NDEBUG + for (const CaseCluster &CC : Clusters) + assert(CC.Low == CC.High && "Input clusters must be single-case"); +#endif - // Extract cases from the switch and sort them. - typedef std::pair<const ConstantInt*, unsigned> CasePair; - std::vector<CasePair> Sorted; - Sorted.reserve(SI->getNumCases()); - for (auto I : SI->cases()) - Sorted.push_back(std::make_pair(I.getCaseValue(), I.getSuccessorIndex())); - std::sort(Sorted.begin(), Sorted.end(), [](CasePair a, CasePair b) { - return a.first->getValue().slt(b.first->getValue()); + std::sort(Clusters.begin(), Clusters.end(), + [](const CaseCluster &a, const CaseCluster &b) { + return a.Low->getValue().slt(b.Low->getValue()); }); - // Merge adjacent cases with the same destination, build Cases vector. - assert(Cases.empty() && "Cases should be empty before Clusterify;"); - Cases.reserve(SI->getNumCases()); - MachineBasicBlock *PreviousSucc = nullptr; - for (CasePair &CP : Sorted) { - const ConstantInt *CaseVal = CP.first; - unsigned SuccIndex = CP.second; - MachineBasicBlock *Succ = FuncInfo.MBBMap[SI->getSuccessor(SuccIndex)]; - uint32_t Weight = BPI ? BPI->getEdgeWeight(SI->getParent(), SuccIndex) : 0; - - if (PreviousSucc == Succ && - (CaseVal->getValue() - Cases.back().High->getValue()) == 1) { + // Merge adjacent clusters with the same destination. + const unsigned N = Clusters.size(); + unsigned DstIndex = 0; + for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) { + CaseCluster &CC = Clusters[SrcIndex]; + const ConstantInt *CaseVal = CC.Low; + MachineBasicBlock *Succ = CC.MBB; + + if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ && + (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) { // If this case has the same successor and is a neighbour, merge it into // the previous cluster. - Cases.back().High = CaseVal; - Cases.back().ExtraWeight += Weight; + Clusters[DstIndex - 1].High = CaseVal; + Clusters[DstIndex - 1].Weight += CC.Weight; } else { - Cases.push_back(Case(CaseVal, CaseVal, Succ, Weight)); + std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex], + sizeof(Clusters[SrcIndex])); } - - PreviousSucc = Succ; } - - DEBUG({ - size_t numCmps = 0; - for (auto &I : Cases) - // A range counts double, since it requires two compares. - numCmps += I.Low != I.High ? 2 : 1; - - dbgs() << "Clusterify finished. Total clusters: " << Cases.size() - << ". 
Total compares: " << numCmps << '\n'; - }); + Clusters.resize(DstIndex); } void SelectionDAGBuilder::UpdateSplitBlock(MachineBasicBlock *First, @@ -2702,90 +2147,6 @@ void SelectionDAGBuilder::UpdateSplitBlock(MachineBasicBlock *First, BitTestCases[i].Parent = Last; } -void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) { - MachineBasicBlock *SwitchMBB = FuncInfo.MBB; - - // Create a vector of Cases, sorted so that we can efficiently create a binary - // search tree from them. - CaseVector Cases; - Clusterify(Cases, &SI); - - // Get the default destination MBB. - MachineBasicBlock *Default = FuncInfo.MBBMap[SI.getDefaultDest()]; - - if (isa<UnreachableInst>(SI.getDefaultDest()->getFirstNonPHIOrDbg()) && - !Cases.empty()) { - // Replace an unreachable default destination with the most popular case - // destination. - DenseMap<const BasicBlock *, unsigned> Popularity; - unsigned MaxPop = 0; - const BasicBlock *MaxBB = nullptr; - for (auto I : SI.cases()) { - const BasicBlock *BB = I.getCaseSuccessor(); - if (++Popularity[BB] > MaxPop) { - MaxPop = Popularity[BB]; - MaxBB = BB; - } - } - - // Set new default. - assert(MaxPop > 0); - assert(MaxBB); - Default = FuncInfo.MBBMap[MaxBB]; - - // Remove cases that were pointing to the destination that is now the default. - Cases.erase(std::remove_if(Cases.begin(), Cases.end(), - [&](const Case &C) { return C.BB == Default; }), - Cases.end()); - } - - // If there is only the default destination, go there directly. - if (Cases.empty()) { - // Update machine-CFG edges. - SwitchMBB->addSuccessor(Default); - - // If this is not a fall-through branch, emit the branch. - if (Default != NextBlock(SwitchMBB)) { - DAG.setRoot(DAG.getNode(ISD::BR, getCurSDLoc(), MVT::Other, - getControlRoot(), DAG.getBasicBlock(Default))); - } - return; - } - - // Get the Value to be switched on. - const Value *SV = SI.getCondition(); - - // Push the initial CaseRec onto the worklist - CaseRecVector WorkList; - WorkList.push_back(CaseRec(SwitchMBB,nullptr,nullptr, - CaseRange(Cases.begin(),Cases.end()))); - - while (!WorkList.empty()) { - // Grab a record representing a case range to process off the worklist - CaseRec CR = WorkList.back(); - WorkList.pop_back(); - - if (handleBitTestsSwitchCase(CR, WorkList, SV, Default, SwitchMBB)) - continue; - - // If the range has few cases (two or less) emit a series of specific - // tests. - if (handleSmallSwitchRange(CR, WorkList, SV, Default, SwitchMBB)) - continue; - - // If the switch has more than N blocks, and is at least 40% dense, and the - // target supports indirect branches, then emit a jump table rather than - // lowering the switch to a binary tree of conditional branches. - // N defaults to 4 and is controlled via TLS.getMinimumJumpTableEntries(). - if (handleJTSwitchCase(CR, WorkList, SV, Default, SwitchMBB)) - continue; - - // Emit binary tree. We need to pick a pivot, and push left and right ranges - // onto the worklist. Leafs are handled via handleSmallSwitchRange() call. 
- handleBTSplitSwitchCase(CR, WorkList, SV, SwitchMBB); - } -} - void SelectionDAGBuilder::visitIndirectBr(const IndirectBrInst &I) { MachineBasicBlock *IndirectBrMBB = FuncInfo.MBB; @@ -7819,3 +7180,759 @@ void SelectionDAGBuilder::updateDAGForMaybeTailCall(SDValue MaybeTC) { HasTailCall = true; } +bool SelectionDAGBuilder::isDense(const CaseClusterVector &Clusters, + unsigned *TotalCases, unsigned First, + unsigned Last) { + assert(Last >= First); + assert(TotalCases[Last] >= TotalCases[First]); + + APInt LowCase = Clusters[First].Low->getValue(); + APInt HighCase = Clusters[Last].High->getValue(); + assert(LowCase.getBitWidth() == HighCase.getBitWidth()); + + // FIXME: A range of consecutive cases has 100% density, but only requires one + // comparison to lower. We should discriminate against such consecutive ranges + // in jump tables. + + uint64_t Diff = (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100); + uint64_t Range = Diff + 1; + + uint64_t NumCases = + TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]); + + assert(NumCases < UINT64_MAX / 100); + assert(Range >= NumCases); + + return NumCases * 100 >= Range * MinJumpTableDensity; +} + +static inline bool areJTsAllowed(const TargetLowering &TLI) { + return TLI.isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) || + TLI.isOperationLegalOrCustom(ISD::BRIND, MVT::Other); +} + +bool SelectionDAGBuilder::buildJumpTable(CaseClusterVector &Clusters, + unsigned First, unsigned Last, + const SwitchInst *SI, + MachineBasicBlock *DefaultMBB, + CaseCluster &JTCluster) { + assert(First <= Last); + + uint64_t Weight = 0; + unsigned NumCmps = 0; + std::vector<MachineBasicBlock*> Table; + DenseMap<MachineBasicBlock*, uint32_t> JTWeights; + for (unsigned I = First; I <= Last; ++I) { + assert(Clusters[I].Kind == CC_Range); + Weight += Clusters[I].Weight; + APInt Low = Clusters[I].Low->getValue(); + APInt High = Clusters[I].High->getValue(); + NumCmps += (Low == High) ? 1 : 2; + if (I != First) { + // Fill the gap between this and the previous cluster. + APInt PreviousHigh = Clusters[I - 1].High->getValue(); + assert(PreviousHigh.slt(Low)); + uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1; + for (uint64_t J = 0; J < Gap; J++) + Table.push_back(DefaultMBB); + } + for (APInt X = Low; X.sle(High); ++X) + Table.push_back(Clusters[I].MBB); + JTWeights[Clusters[I].MBB] += Clusters[I].Weight; + } + + unsigned NumDests = JTWeights.size(); + if (isSuitableForBitTests(NumDests, NumCmps, + Clusters[First].Low->getValue(), + Clusters[Last].High->getValue())) { + // Clusters[First..Last] should be lowered as bit tests instead. + return false; + } + + // Create the MBB that will load from and jump through the table. + // Note: We create it here, but it's not inserted into the function yet. + MachineFunction *CurMF = FuncInfo.MF; + MachineBasicBlock *JumpTableMBB = + CurMF->CreateMachineBasicBlock(SI->getParent()); + + // Add successors. Note: use table order for determinism. + SmallPtrSet<MachineBasicBlock *, 8> Done; + for (MachineBasicBlock *Succ : Table) { + if (Done.count(Succ)) + continue; + addSuccessorWithWeight(JumpTableMBB, Succ, JTWeights[Succ]); + Done.insert(Succ); + } + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding()) + ->createJumpTableIndex(Table); + + // Set up the jump table info. 
+ JumpTable JT(-1U, JTI, JumpTableMBB, nullptr); + JumpTableHeader JTH(Clusters[First].Low->getValue(), + Clusters[Last].High->getValue(), SI->getCondition(), + nullptr, false); + JTCases.push_back(JumpTableBlock(JTH, JT)); + + JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High, + JTCases.size() - 1, Weight); + return true; +} + +void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters, + const SwitchInst *SI, + MachineBasicBlock *DefaultMBB) { +#ifndef NDEBUG + // Clusters must be non-empty, sorted, and only contain Range clusters. + assert(!Clusters.empty()); + for (CaseCluster &C : Clusters) + assert(C.Kind == CC_Range); + for (unsigned i = 1, e = Clusters.size(); i < e; ++i) + assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue())); +#endif + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!areJTsAllowed(TLI)) + return; + + const int64_t N = Clusters.size(); + const unsigned MinJumpTableSize = TLI.getMinimumJumpTableEntries(); + + // Split Clusters into minimum number of dense partitions. The algorithm uses + // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code + // for the Case Statement'" (1994), but builds the MinPartitions array in + // reverse order to make it easier to reconstruct the partitions in ascending + // order. In the choice between two optimal partitionings, it picks the one + // which yields more jump tables. + + // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1]. + SmallVector<unsigned, 8> MinPartitions(N); + // LastElement[i] is the last element of the partition starting at i. + SmallVector<unsigned, 8> LastElement(N); + // NumTables[i]: nbr of >= MinJumpTableSize partitions from Clusters[i..N-1]. + SmallVector<unsigned, 8> NumTables(N); + // TotalCases[i]: Total nbr of cases in Clusters[0..i]. + SmallVector<unsigned, 8> TotalCases(N); + + for (unsigned i = 0; i < N; ++i) { + APInt Hi = Clusters[i].High->getValue(); + APInt Lo = Clusters[i].Low->getValue(); + TotalCases[i] = (Hi - Lo).getLimitedValue() + 1; + if (i != 0) + TotalCases[i] += TotalCases[i - 1]; + } + + // Base case: There is only one way to partition Clusters[N-1]. + MinPartitions[N - 1] = 1; + LastElement[N - 1] = N - 1; + assert(MinJumpTableSize > 1); + NumTables[N - 1] = 0; + + // Note: loop indexes are signed to avoid underflow. + for (int64_t i = N - 2; i >= 0; i--) { + // Find optimal partitioning of Clusters[i..N-1]. + // Baseline: Put Clusters[i] into a partition on its own. + MinPartitions[i] = MinPartitions[i + 1] + 1; + LastElement[i] = i; + NumTables[i] = NumTables[i + 1]; + + // Search for a solution that results in fewer partitions. + for (int64_t j = N - 1; j > i; j--) { + // Try building a partition from Clusters[i..j]. + if (isDense(Clusters, &TotalCases[0], i, j)) { + unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]); + bool IsTable = j - i + 1 >= MinJumpTableSize; + unsigned Tables = IsTable + (j == N - 1 ? 0 : NumTables[j + 1]); + + // If this j leads to fewer partitions, or same number of partitions + // with more lookup tables, it is a better partitioning. + if (NumPartitions < MinPartitions[i] || + (NumPartitions == MinPartitions[i] && Tables > NumTables[i])) { + MinPartitions[i] = NumPartitions; + LastElement[i] = j; + NumTables[i] = Tables; + } + } + } + } + + // Iterate over the partitions, replacing some with jump tables in-place. 
+ unsigned DstIndex = 0; + for (unsigned First = 0, Last; First < N; First = Last + 1) { + Last = LastElement[First]; + assert(Last >= First); + assert(DstIndex <= First); + unsigned NumClusters = Last - First + 1; + + CaseCluster JTCluster; + if (NumClusters >= MinJumpTableSize && + buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) { + Clusters[DstIndex++] = JTCluster; + } else { + for (unsigned I = First; I <= Last; ++I) + std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I])); + } + } + Clusters.resize(DstIndex); +} + +bool SelectionDAGBuilder::rangeFitsInWord(const APInt &Low, const APInt &High) { + // FIXME: Using the pointer type doesn't seem ideal. + uint64_t BW = DAG.getTargetLoweringInfo().getPointerTy().getSizeInBits(); + uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1; + return Range <= BW; +} + +bool SelectionDAGBuilder::isSuitableForBitTests(unsigned NumDests, + unsigned NumCmps, + const APInt &Low, + const APInt &High) { + // FIXME: I don't think NumCmps is the correct metric: a single case and a + // range of cases both require only one branch to lower. Just looking at the + // number of clusters and destinations should be enough to decide whether to + // build bit tests. + + // To lower a range with bit tests, the range must fit the bitwidth of a + // machine word. + if (!rangeFitsInWord(Low, High)) + return false; + + // Decide whether it's profitable to lower this range with bit tests. Each + // destination requires a bit test and branch, and there is an overall range + // check branch. For a small number of clusters, separate comparisons might be + // cheaper, and for many destinations, splitting the range might be better. + return (NumDests == 1 && NumCmps >= 3) || + (NumDests == 2 && NumCmps >= 5) || + (NumDests == 3 && NumCmps >= 6); +} + +bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters, + unsigned First, unsigned Last, + const SwitchInst *SI, + CaseCluster &BTCluster) { + assert(First <= Last); + if (First == Last) + return false; + + BitVector Dests(FuncInfo.MF->getNumBlockIDs()); + unsigned NumCmps = 0; + for (int64_t I = First; I <= Last; ++I) { + assert(Clusters[I].Kind == CC_Range); + Dests.set(Clusters[I].MBB->getNumber()); + NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2; + } + unsigned NumDests = Dests.count(); + + APInt Low = Clusters[First].Low->getValue(); + APInt High = Clusters[Last].High->getValue(); + assert(Low.slt(High)); + + if (!isSuitableForBitTests(NumDests, NumCmps, Low, High)) + return false; + + APInt LowBound; + APInt CmpRange; + + const int BitWidth = + DAG.getTargetLoweringInfo().getPointerTy().getSizeInBits(); + assert((High - Low + 1).sle(BitWidth) && "Case range must fit in bit mask!"); + + if (Low.isNonNegative() && High.slt(BitWidth)) { + // Optimize the case where all the case values fit in a + // word without having to subtract minValue. In this case, + // we can optimize away the subtraction. + LowBound = APInt::getNullValue(Low.getBitWidth()); + CmpRange = High; + } else { + LowBound = Low; + CmpRange = High - Low; + } + + CaseBitsVector CBV; + uint64_t TotalWeight = 0; + for (unsigned i = First; i <= Last; ++i) { + // Find the CaseBits for this destination. + unsigned j; + for (j = 0; j < CBV.size(); ++j) + if (CBV[j].BB == Clusters[i].MBB) + break; + if (j == CBV.size()) + CBV.push_back(CaseBits(0, Clusters[i].MBB, 0, 0)); + CaseBits *CB = &CBV[j]; + + // Update Mask, Bits and ExtraWeight. 
+ uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue(); + uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue(); + for (uint64_t j = Lo; j <= Hi; ++j) { + CB->Mask |= 1ULL << j; + CB->Bits++; + } + CB->ExtraWeight += Clusters[i].Weight; + TotalWeight += Clusters[i].Weight; + } + + BitTestInfo BTI; + std::sort(CBV.begin(), CBV.end(), [](const CaseBits &a, const CaseBits &b) { + // FIXME: Sort by weight. + return a.Bits > b.Bits; + }); + + for (auto &CB : CBV) { + MachineBasicBlock *BitTestBB = + FuncInfo.MF->CreateMachineBasicBlock(SI->getParent()); + BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraWeight)); + } + BitTestCases.push_back(BitTestBlock(LowBound, CmpRange, SI->getCondition(), + -1U, MVT::Other, false, nullptr, + nullptr, std::move(BTI))); + + BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High, + BitTestCases.size() - 1, TotalWeight); + return true; +} + +void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters, + const SwitchInst *SI) { +// Partition Clusters into as few subsets as possible, where each subset has a +// range that fits in a machine word and has <= 3 unique destinations. + +#ifndef NDEBUG + // Clusters must be sorted and contain Range or JumpTable clusters. + assert(!Clusters.empty()); + assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable); + for (const CaseCluster &C : Clusters) + assert(C.Kind == CC_Range || C.Kind == CC_JumpTable); + for (unsigned i = 1; i < Clusters.size(); ++i) + assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue())); +#endif + + // If target does not have legal shift left, do not emit bit tests at all. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT PTy = TLI.getPointerTy(); + if (!TLI.isOperationLegal(ISD::SHL, PTy)) + return; + + int BitWidth = PTy.getSizeInBits(); + const int64_t N = Clusters.size(); + + // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1]. + SmallVector<unsigned, 8> MinPartitions(N); + // LastElement[i] is the last element of the partition starting at i. + SmallVector<unsigned, 8> LastElement(N); + + // FIXME: This might not be the best algorithm for finding bit test clusters. + + // Base case: There is only one way to partition Clusters[N-1]. + MinPartitions[N - 1] = 1; + LastElement[N - 1] = N - 1; + + // Note: loop indexes are signed to avoid underflow. + for (int64_t i = N - 2; i >= 0; --i) { + // Find optimal partitioning of Clusters[i..N-1]. + // Baseline: Put Clusters[i] into a partition on its own. + MinPartitions[i] = MinPartitions[i + 1] + 1; + LastElement[i] = i; + + // Search for a solution that results in fewer partitions. + // Note: the search is limited by BitWidth, reducing time complexity. + for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) { + // Try building a partition from Clusters[i..j]. + + // Check the range. + if (!rangeFitsInWord(Clusters[i].Low->getValue(), + Clusters[j].High->getValue())) + continue; + + // Check nbr of destinations and cluster types. + // FIXME: This works, but doesn't seem very efficient. + bool RangesOnly = true; + BitVector Dests(FuncInfo.MF->getNumBlockIDs()); + for (int64_t k = i; k <= j; k++) { + if (Clusters[k].Kind != CC_Range) { + RangesOnly = false; + break; + } + Dests.set(Clusters[k].MBB->getNumber()); + } + if (!RangesOnly || Dests.count() > 3) + break; + + // Check if it's a better partition. + unsigned NumPartitions = 1 + (j == N - 1 ? 
0 : MinPartitions[j + 1]); + if (NumPartitions < MinPartitions[i]) { + // Found a better partition. + MinPartitions[i] = NumPartitions; + LastElement[i] = j; + } + } + } + + // Iterate over the partitions, replacing with bit-test clusters in-place. + unsigned DstIndex = 0; + for (unsigned First = 0, Last; First < N; First = Last + 1) { + Last = LastElement[First]; + assert(First <= Last); + assert(DstIndex <= First); + + CaseCluster BitTestCluster; + if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) { + Clusters[DstIndex++] = BitTestCluster; + } else { + for (unsigned I = First; I <= Last; ++I) + std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I])); + } + } + Clusters.resize(DstIndex); +} + +void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond, + MachineBasicBlock *SwitchMBB, + MachineBasicBlock *DefaultMBB) { + MachineFunction *CurMF = FuncInfo.MF; + MachineBasicBlock *NextMBB = nullptr; + MachineFunction::iterator BBI = W.MBB; + if (++BBI != FuncInfo.MF->end()) + NextMBB = BBI; + + unsigned Size = W.LastCluster - W.FirstCluster + 1; + + BranchProbabilityInfo *BPI = FuncInfo.BPI; + + if (Size == 2 && W.MBB == SwitchMBB) { + // If any two of the cases has the same destination, and if one value + // is the same as the other, but has one bit unset that the other has set, + // use bit manipulation to do two compares at once. For example: + // "if (X == 6 || X == 4)" -> "if ((X|2) == 6)" + // TODO: This could be extended to merge any 2 cases in switches with 3 + // cases. + // TODO: Handle cases where W.CaseBB != SwitchBB. + CaseCluster &Small = *W.FirstCluster; + CaseCluster &Big = *W.LastCluster; + + if (Small.Low == Small.High && Big.Low == Big.High && + Small.MBB == Big.MBB) { + const APInt &SmallValue = Small.Low->getValue(); + const APInt &BigValue = Big.Low->getValue(); + + // Check that there is only one bit different. + if (BigValue.countPopulation() == SmallValue.countPopulation() + 1 && + (SmallValue | BigValue) == BigValue) { + // Isolate the common bit. + APInt CommonBit = BigValue & ~SmallValue; + assert((SmallValue | CommonBit) == BigValue && + CommonBit.countPopulation() == 1 && "Not a common bit?"); + + SDValue CondLHS = getValue(Cond); + EVT VT = CondLHS.getValueType(); + SDLoc DL = getCurSDLoc(); + + SDValue Or = DAG.getNode(ISD::OR, DL, VT, CondLHS, + DAG.getConstant(CommonBit, VT)); + SDValue Cond = DAG.getSetCC(DL, MVT::i1, Or, + DAG.getConstant(BigValue, VT), ISD::SETEQ); + + // Update successor info. + // Both Small and Big will jump to Small.BB, so we sum up the weights. + addSuccessorWithWeight(SwitchMBB, Small.MBB, Small.Weight + Big.Weight); + addSuccessorWithWeight( + SwitchMBB, DefaultMBB, + // The default destination is the first successor in IR. + BPI ? BPI->getEdgeWeight(SwitchMBB->getBasicBlock(), (unsigned)0) + : 0); + + // Insert the true branch. + SDValue BrCond = + DAG.getNode(ISD::BRCOND, DL, MVT::Other, getControlRoot(), Cond, + DAG.getBasicBlock(Small.MBB)); + // Insert the false branch. + BrCond = DAG.getNode(ISD::BR, DL, MVT::Other, BrCond, + DAG.getBasicBlock(DefaultMBB)); + + DAG.setRoot(BrCond); + return; + } + } + } + + if (TM.getOptLevel() != CodeGenOpt::None) { + // Order cases by weight so the most likely case will be checked first. + std::sort(W.FirstCluster, W.LastCluster + 1, + [](const CaseCluster &a, const CaseCluster &b) { + return a.Weight > b.Weight; + }); + + // Rearrange the case blocks so that the last one falls through if possible. 
+ // Start at the bottom as that's the case with the lowest weight. + // FIXME: Take branch probability into account. + for (CaseClusterIt I = W.LastCluster - 1; I >= W.FirstCluster; --I) { + if (I->Kind == CC_Range && I->MBB == NextMBB) { + std::swap(*I, *W.LastCluster); + break; + } + } + } + + // Compute total weight. + uint32_t UnhandledWeights = 0; + for (CaseClusterIt I = W.FirstCluster; I <= W.LastCluster; ++I) + UnhandledWeights += I->Weight; + + MachineBasicBlock *CurMBB = W.MBB; + for (CaseClusterIt I = W.FirstCluster, E = W.LastCluster; I <= E; ++I) { + MachineBasicBlock *Fallthrough; + if (I == W.LastCluster) { + // For the last cluster, fall through to the default destination. + Fallthrough = DefaultMBB; + } else { + Fallthrough = CurMF->CreateMachineBasicBlock(CurMBB->getBasicBlock()); + CurMF->insert(BBI, Fallthrough); + // Put Cond in a virtual register to make it available from the new blocks. + ExportFromCurrentBlock(Cond); + } + + switch (I->Kind) { + case CC_JumpTable: { + // FIXME: Optimize away range check based on pivot comparisons. + JumpTableHeader *JTH = &JTCases[I->JTCasesIndex].first; + JumpTable *JT = &JTCases[I->JTCasesIndex].second; + + // The jump block hasn't been inserted yet; insert it here. + MachineBasicBlock *JumpMBB = JT->MBB; + CurMF->insert(BBI, JumpMBB); + addSuccessorWithWeight(CurMBB, Fallthrough); + addSuccessorWithWeight(CurMBB, JumpMBB); + + // The jump table header will be inserted in our current block, do the + // range check, and fall through to our fallthrough block. + JTH->HeaderBB = CurMBB; + JT->Default = Fallthrough; // FIXME: Move Default to JumpTableHeader. + + // If we're in the right place, emit the jump table header right now. + if (CurMBB == SwitchMBB) { + visitJumpTableHeader(*JT, *JTH, SwitchMBB); + JTH->Emitted = true; + } + break; + } + case CC_BitTests: { + // FIXME: Optimize away range check based on pivot comparisons. + BitTestBlock *BTB = &BitTestCases[I->BTCasesIndex]; + + // The bit test blocks haven't been inserted yet; insert them here. + for (BitTestCase &BTC : BTB->Cases) + CurMF->insert(BBI, BTC.ThisBB); + + // Fill in fields of the BitTestBlock. + BTB->Parent = CurMBB; + BTB->Default = Fallthrough; + + // If we're in the right place, emit the bit test header header right now. + if (CurMBB ==SwitchMBB) { + visitBitTestHeader(*BTB, SwitchMBB); + BTB->Emitted = true; + } + break; + } + case CC_Range: { + const Value *RHS, *LHS, *MHS; + ISD::CondCode CC; + if (I->Low == I->High) { + // Check Cond == I->Low. + CC = ISD::SETEQ; + LHS = Cond; + RHS=I->Low; + MHS = nullptr; + } else { + // Check I->Low <= Cond <= I->High. + CC = ISD::SETLE; + LHS = I->Low; + MHS = Cond; + RHS = I->High; + } + + // The false weight is the sum of all unhandled cases. + UnhandledWeights -= I->Weight; + CaseBlock CB(CC, LHS, RHS, MHS, I->MBB, Fallthrough, CurMBB, I->Weight, + UnhandledWeights); + + if (CurMBB == SwitchMBB) + visitSwitchCase(CB, SwitchMBB); + else + SwitchCases.push_back(CB); + + break; + } + } + CurMBB = Fallthrough; + } +} + +void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList, + const SwitchWorkListItem &W, + Value *Cond, + MachineBasicBlock *SwitchMBB) { + assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) && + "Clusters not sorted?"); + + unsigned NumClusters = W.LastCluster - W.FirstCluster + 1; + assert(NumClusters >= 2 && "Too small to split!"); + + // FIXME: When we have profile info, we might want to balance the tree based + // on weights instead of node count. 
+ + CaseClusterIt PivotCluster = W.FirstCluster + NumClusters / 2; + CaseClusterIt FirstLeft = W.FirstCluster; + CaseClusterIt LastLeft = PivotCluster - 1; + CaseClusterIt FirstRight = PivotCluster; + CaseClusterIt LastRight = W.LastCluster; + const ConstantInt *Pivot = PivotCluster->Low; + + // New blocks will be inserted immediately after the current one. + MachineFunction::iterator BBI = W.MBB; + ++BBI; + + // We will branch to the LHS if Value < Pivot. If LHS is a single cluster, + // we can branch to its destination directly if it's squeezed exactly in + // between the known lower bound and Pivot - 1. + MachineBasicBlock *LeftMBB; + if (FirstLeft == LastLeft && FirstLeft->Kind == CC_Range && + FirstLeft->Low == W.GE && + (FirstLeft->High->getValue() + 1LL) == Pivot->getValue()) { + LeftMBB = FirstLeft->MBB; + } else { + LeftMBB = FuncInfo.MF->CreateMachineBasicBlock(W.MBB->getBasicBlock()); + FuncInfo.MF->insert(BBI, LeftMBB); + WorkList.push_back({LeftMBB, FirstLeft, LastLeft, W.GE, Pivot}); + // Put Cond in a virtual register to make it available from the new blocks. + ExportFromCurrentBlock(Cond); + } + + // Similarly, we will branch to the RHS if Value >= Pivot. If RHS is a + // single cluster, RHS.Low == Pivot, and we can branch to its destination + // directly if RHS.High equals the current upper bound. + MachineBasicBlock *RightMBB; + if (FirstRight == LastRight && FirstRight->Kind == CC_Range && + W.LT && (FirstRight->High->getValue() + 1ULL) == W.LT->getValue()) { + RightMBB = FirstRight->MBB; + } else { + RightMBB = FuncInfo.MF->CreateMachineBasicBlock(W.MBB->getBasicBlock()); + FuncInfo.MF->insert(BBI, RightMBB); + WorkList.push_back({RightMBB, FirstRight, LastRight, Pivot, W.LT}); + // Put Cond in a virtual register to make it available from the new blocks. + ExportFromCurrentBlock(Cond); + } + + // Create the CaseBlock record that will be used to lower the branch. + CaseBlock CB(ISD::SETLT, Cond, Pivot, nullptr, LeftMBB, RightMBB, W.MBB); + + if (W.MBB == SwitchMBB) + visitSwitchCase(CB, SwitchMBB); + else + SwitchCases.push_back(CB); +} + +void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) { + // Extract cases from the switch. + BranchProbabilityInfo *BPI = FuncInfo.BPI; + CaseClusterVector Clusters; + Clusters.reserve(SI.getNumCases()); + for (auto I : SI.cases()) { + MachineBasicBlock *Succ = FuncInfo.MBBMap[I.getCaseSuccessor()]; + const ConstantInt *CaseVal = I.getCaseValue(); + uint32_t Weight = 0; // FIXME: Use 1 instead? + if (BPI) + Weight = BPI->getEdgeWeight(SI.getParent(), I.getSuccessorIndex()); + Clusters.push_back(CaseCluster::range(CaseVal, CaseVal, Succ, Weight)); + } + + MachineBasicBlock *DefaultMBB = FuncInfo.MBBMap[SI.getDefaultDest()]; + + if (TM.getOptLevel() != CodeGenOpt::None) { + // Cluster adjacent cases with the same destination. + sortAndRangeify(Clusters); + + // Replace an unreachable default with the most popular destination. + // FIXME: Exploit unreachable default more aggressively. + bool UnreachableDefault = + isa<UnreachableInst>(SI.getDefaultDest()->getFirstNonPHIOrDbg()); + if (UnreachableDefault && !Clusters.empty()) { + DenseMap<const BasicBlock *, unsigned> Popularity; + unsigned MaxPop = 0; + const BasicBlock *MaxBB = nullptr; + for (auto I : SI.cases()) { + const BasicBlock *BB = I.getCaseSuccessor(); + if (++Popularity[BB] > MaxPop) { + MaxPop = Popularity[BB]; + MaxBB = BB; + } + } + // Set new default. 
+ assert(MaxPop > 0 && MaxBB); + DefaultMBB = FuncInfo.MBBMap[MaxBB]; + + // Remove cases that were pointing to the destination that is now the + // default. + CaseClusterVector New; + New.reserve(Clusters.size()); + for (CaseCluster &CC : Clusters) { + if (CC.MBB != DefaultMBB) + New.push_back(CC); + } + Clusters = std::move(New); + } + } + + // If there is only the default destination, jump there directly. + MachineBasicBlock *SwitchMBB = FuncInfo.MBB; + if (Clusters.empty()) { + SwitchMBB->addSuccessor(DefaultMBB); + if (DefaultMBB != NextBlock(SwitchMBB)) { + DAG.setRoot(DAG.getNode(ISD::BR, getCurSDLoc(), MVT::Other, + getControlRoot(), DAG.getBasicBlock(DefaultMBB))); + } + return; + } + + if (TM.getOptLevel() != CodeGenOpt::None) { + findJumpTables(Clusters, &SI, DefaultMBB); + findBitTestClusters(Clusters, &SI); + } + + + DEBUG({ + dbgs() << "Case clusters: "; + for (const CaseCluster &C : Clusters) { + if (C.Kind == CC_JumpTable) dbgs() << "JT:"; + if (C.Kind == CC_BitTests) dbgs() << "BT:"; + + C.Low->getValue().print(dbgs(), true); + if (C.Low != C.High) { + dbgs() << '-'; + C.High->getValue().print(dbgs(), true); + } + dbgs() << ' '; + } + dbgs() << '\n'; + }); + + assert(!Clusters.empty()); + SwitchWorkList WorkList; + CaseClusterIt First = Clusters.begin(); + CaseClusterIt Last = Clusters.end() - 1; + WorkList.push_back({SwitchMBB, First, Last, nullptr, nullptr}); + + while (!WorkList.empty()) { + SwitchWorkListItem W = WorkList.back(); + WorkList.pop_back(); + unsigned NumClusters = W.LastCluster - W.FirstCluster + 1; + + if (NumClusters > 3 && TM.getOptLevel() != CodeGenOpt::None) { + // For optimized builds, lower large range as a balanced binary tree. + splitWorkItem(WorkList, W, SI.getCondition(), SwitchMBB); + continue; + } + + lowerWorkItem(W, SI.getCondition(), SwitchMBB, DefaultMBB); + } +} |