diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LowerSwitch.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LowerSwitch.cpp | 205 |
1 files changed, 137 insertions, 68 deletions
diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp index fec13c8c160..08db63ed8d6 100644 --- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -16,8 +16,12 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -27,6 +31,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -77,6 +82,10 @@ namespace { bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LazyValueInfoWrapperPass>(); + } + struct CaseRange { ConstantInt* Low; ConstantInt* High; @@ -90,15 +99,18 @@ namespace { using CaseItr = std::vector<CaseRange>::iterator; private: - void processSwitchInst(SwitchInst *SI, SmallPtrSetImpl<BasicBlock*> &DeleteList); + void processSwitchInst(SwitchInst *SI, + SmallPtrSetImpl<BasicBlock *> &DeleteList, + AssumptionCache *AC, LazyValueInfo *LVI); BasicBlock *switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ConstantInt *UpperBound, Value *Val, BasicBlock *Predecessor, BasicBlock *OrigBlock, BasicBlock *Default, const std::vector<IntRange> &UnreachableRanges); - BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, - BasicBlock *Default); + BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, + ConstantInt *LowerBound, ConstantInt *UpperBound, + BasicBlock *OrigBlock, BasicBlock *Default); unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); }; @@ -120,8 +132,12 @@ char LowerSwitch::ID = 0; // Publicly exposed interface to pass... char &llvm::LowerSwitchID = LowerSwitch::ID; -INITIALIZE_PASS(LowerSwitch, "lowerswitch", - "Lower SwitchInst's to branches", false, false) +INITIALIZE_PASS_BEGIN(LowerSwitch, "lowerswitch", + "Lower SwitchInst's to branches", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_END(LowerSwitch, "lowerswitch", + "Lower SwitchInst's to branches", false, false) // createLowerSwitchPass - Interface to this file... FunctionPass *llvm::createLowerSwitchPass() { @@ -129,6 +145,17 @@ FunctionPass *llvm::createLowerSwitchPass() { } bool LowerSwitch::runOnFunction(Function &F) { + LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>(); + AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr; + // Prevent LazyValueInfo from using the DominatorTree as LowerSwitch does not + // preserve it and it becomes stale (when available) pretty much immediately. + // Currently the DominatorTree is only used by LowerSwitch indirectly via LVI + // and computeKnownBits to refine isValidAssumeForContext's results. Given + // that the latter can handle some of the simple cases w/o a DominatorTree, + // it's easier to refrain from using the tree than to keep it up to date. + LVI->disableDT(); + bool Changed = false; SmallPtrSet<BasicBlock*, 8> DeleteList; @@ -142,11 +169,12 @@ bool LowerSwitch::runOnFunction(Function &F) { if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { Changed = true; - processSwitchInst(SI, DeleteList); + processSwitchInst(SI, DeleteList, AC, LVI); } } for (BasicBlock* BB: DeleteList) { + LVI->eraseBlock(BB); DeleteDeadBlock(BB); } @@ -159,10 +187,11 @@ static raw_ostream &operator<<(raw_ostream &O, const LowerSwitch::CaseVector &C) { O << "["; - for (LowerSwitch::CaseVector::const_iterator B = C.begin(), - E = C.end(); B != E; ) { - O << *B->Low << " -" << *B->High; - if (++B != E) O << ", "; + for (LowerSwitch::CaseVector::const_iterator B = C.begin(), E = C.end(); + B != E;) { + O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]"; + if (++B != E) + O << ", "; } return O << "]"; @@ -178,8 +207,9 @@ static raw_ostream &operator<<(raw_ostream &O, /// 2) Removed if subsequent incoming values now share the same case, i.e., /// multiple outcome edges are condensed into one. This is necessary to keep the /// number of phi values equal to the number of branches to SuccBB. -static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, - unsigned NumMergedCases) { +static void +fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, + const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI()->getIterator(); I != IE; ++I) { @@ -221,6 +251,7 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, BasicBlock *Predecessor, BasicBlock *OrigBlock, BasicBlock *Default, const std::vector<IntRange> &UnreachableRanges) { + assert(LowerBound && UpperBound && "Bounds must be initialized"); unsigned Size = End - Begin; if (Size == 1) { @@ -230,13 +261,12 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, // because the bounds already tell us so. if (Begin->Low == LowerBound && Begin->High == UpperBound) { unsigned NumMergedCases = 0; - if (LowerBound && UpperBound) - NumMergedCases = - UpperBound->getSExtValue() - LowerBound->getSExtValue(); + NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } - return newLeafBlock(*Begin, Val, OrigBlock, Default); + return newLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock, + Default); } unsigned Mid = Size / 2; @@ -246,8 +276,8 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n"); CaseRange &Pivot = *(Begin + Mid); - LLVM_DEBUG(dbgs() << "Pivot ==> " << Pivot.Low->getValue() << " -" - << Pivot.High->getValue() << "\n"); + LLVM_DEBUG(dbgs() << "Pivot ==> [" << Pivot.Low->getValue() << ", " + << Pivot.High->getValue() << "]\n"); // NewLowerBound here should never be the integer minimal value. // This is because it is computed from a case range that is never @@ -269,14 +299,10 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, NewUpperBound = LHS.back().High; } - LLVM_DEBUG(dbgs() << "LHS Bounds ==> "; if (LowerBound) { - dbgs() << LowerBound->getSExtValue(); - } else { dbgs() << "NONE"; } dbgs() << " - " - << NewUpperBound->getSExtValue() << "\n"; - dbgs() << "RHS Bounds ==> "; - dbgs() << NewLowerBound->getSExtValue() << " - "; if (UpperBound) { - dbgs() << UpperBound->getSExtValue() << "\n"; - } else { dbgs() << "NONE\n"; }); + LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", " + << NewUpperBound->getSExtValue() << "]\n" + << "RHS Bounds ==> [" << NewLowerBound->getSExtValue() + << ", " << UpperBound->getSExtValue() << "]\n"); // Create a new node that checks if the value is < pivot. Go to the // left branch if it is and right branch if not. @@ -304,9 +330,11 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, /// switch's value == the case's value. If not, then it jumps to the default /// branch. At this point in the tree, the value can't be another valid case /// value, so the jump to the "default" branch is warranted. -BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, - BasicBlock* OrigBlock, - BasicBlock* Default) { +BasicBlock *LowerSwitch::newLeafBlock(CaseRange &Leaf, Value *Val, + ConstantInt *LowerBound, + ConstantInt *UpperBound, + BasicBlock *OrigBlock, + BasicBlock *Default) { Function* F = OrigBlock->getParent(); BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); @@ -319,10 +347,14 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, Leaf.Low, "SwitchLeaf"); } else { // Make range comparison - if (Leaf.Low->isMinValue(true /*isSigned*/)) { + if (Leaf.Low == LowerBound) { // Val >= Min && Val <= Hi --> Val <= Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, "SwitchLeaf"); + } else if (Leaf.High == UpperBound) { + // Val <= Max && Val >= Lo --> Val >= Lo + Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, + "SwitchLeaf"); } else if (Leaf.Low->isZero()) { // Val >= 0 && Val <= Hi --> Val <=u Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, @@ -362,14 +394,20 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, return NewLeaf; } -/// Transform simple list of Cases into list of CaseRange's. +/// Transform simple list of \p SI's cases into list of CaseRange's \p Cases. +/// \post \p Cases wouldn't contain references to \p SI's default BB. +/// \returns Number of \p SI's cases that do not reference \p SI's default BB. unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { - unsigned numCmps = 0; + unsigned NumSimpleCases = 0; // Start with "simple" cases - for (auto Case : SI->cases()) + for (auto Case : SI->cases()) { + if (Case.getCaseSuccessor() == SI->getDefaultDest()) + continue; Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), Case.getCaseSuccessor())); + ++NumSimpleCases; + } llvm::sort(Cases, CaseCmp()); @@ -395,60 +433,94 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { Cases.erase(std::next(I), Cases.end()); } - for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { - if (I->Low != I->High) - // A range counts double, since it requires two compares. - ++numCmps; - } + return NumSimpleCases; +} - return numCmps; +static ConstantRange getConstantRangeFromKnownBits(const KnownBits &Known) { + APInt Lower = Known.One; + APInt Upper = ~Known.Zero + 1; + if (Upper == Lower) + return ConstantRange(Known.getBitWidth(), /*isFullSet=*/true); + return ConstantRange(Lower, Upper); } /// Replace the specified switch instruction with a sequence of chained if-then /// insts in a balanced binary search. void LowerSwitch::processSwitchInst(SwitchInst *SI, - SmallPtrSetImpl<BasicBlock*> &DeleteList) { - BasicBlock *CurBlock = SI->getParent(); - BasicBlock *OrigBlock = CurBlock; - Function *F = CurBlock->getParent(); + SmallPtrSetImpl<BasicBlock *> &DeleteList, + AssumptionCache *AC, LazyValueInfo *LVI) { + BasicBlock *OrigBlock = SI->getParent(); + Function *F = OrigBlock->getParent(); Value *Val = SI->getCondition(); // The value we are switching on... BasicBlock* Default = SI->getDefaultDest(); // Don't handle unreachable blocks. If there are successors with phis, this // would leave them behind with missing predecessors. - if ((CurBlock != &F->getEntryBlock() && pred_empty(CurBlock)) || - CurBlock->getSinglePredecessor() == CurBlock) { - DeleteList.insert(CurBlock); + if ((OrigBlock != &F->getEntryBlock() && pred_empty(OrigBlock)) || + OrigBlock->getSinglePredecessor() == OrigBlock) { + DeleteList.insert(OrigBlock); return; } + // Prepare cases vector. + CaseVector Cases; + const unsigned NumSimpleCases = Clusterify(Cases, SI); + LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total non-default cases: " << NumSimpleCases + << "\nCase clusters: " << Cases << "\n"); + // If there is only the default destination, just branch. - if (!SI->getNumCases()) { - BranchInst::Create(Default, CurBlock); + if (Cases.empty()) { + BranchInst::Create(Default, OrigBlock); + // Remove all the references from Default's PHIs to OrigBlock, but one. + fixPhis(Default, OrigBlock, OrigBlock); SI->eraseFromParent(); return; } - // Prepare cases vector. - CaseVector Cases; - unsigned numCmps = Clusterify(Cases, SI); - LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() - << ". Total compares: " << numCmps << "\n"); - LLVM_DEBUG(dbgs() << "Cases: " << Cases << "\n"); - (void)numCmps; - ConstantInt *LowerBound = nullptr; ConstantInt *UpperBound = nullptr; - std::vector<IntRange> UnreachableRanges; + bool DefaultIsUnreachableFromSwitch = false; if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { // Make the bounds tightly fitted around the case value range, because we // know that the value passed to the switch must be exactly one of the case // values. - assert(!Cases.empty()); LowerBound = Cases.front().Low; UpperBound = Cases.back().High; + DefaultIsUnreachableFromSwitch = true; + } else { + // Constraining the range of the value being switched over helps eliminating + // unreachable BBs and minimizing the number of `add` instructions + // newLeafBlock ends up emitting. Running CorrelatedValuePropagation after + // LowerSwitch isn't as good, and also much more expensive in terms of + // compile time for the following reasons: + // 1. it processes many kinds of instructions, not just switches; + // 2. even if limited to icmp instructions only, it will have to process + // roughly C icmp's per switch, where C is the number of cases in the + // switch, while LowerSwitch only needs to call LVI once per switch. + const DataLayout &DL = F->getParent()->getDataLayout(); + KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI); + ConstantRange KnownBitsRange = getConstantRangeFromKnownBits(Known); + const ConstantRange LVIRange = LVI->getConstantRange(Val, OrigBlock, SI); + ConstantRange ValRange = KnownBitsRange.intersectWith(LVIRange); + // We delegate removal of unreachable non-default cases to other passes. In + // the unlikely event that some of them survived, we just conservatively + // maintain the invariant that all the cases lie between the bounds. This + // may, however, still render the default case effectively unreachable. + APInt Low = Cases.front().Low->getValue(); + APInt High = Cases.back().High->getValue(); + APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low); + APInt Max = APIntOps::smax(ValRange.getSignedMax(), High); + + LowerBound = ConstantInt::get(SI->getContext(), Min); + UpperBound = ConstantInt::get(SI->getContext(), Max); + DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max); + } + std::vector<IntRange> UnreachableRanges; + + if (DefaultIsUnreachableFromSwitch) { DenseMap<BasicBlock *, unsigned> Popularity; unsigned MaxPop = 0; BasicBlock *PopSucc = nullptr; @@ -495,8 +567,10 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, #endif // As the default block in the switch is unreachable, update the PHI nodes - // (remove the entry to the default block) to reflect this. - Default->removePredecessor(OrigBlock); + // (remove all of the references to the default block) to reflect this. + const unsigned NumDefaultEdges = SI->getNumCases() + 1 - NumSimpleCases; + for (unsigned I = 0; I < NumDefaultEdges; ++I) + Default->removePredecessor(OrigBlock); // Use the most popular block as the new default, reducing the number of // cases. @@ -509,7 +583,7 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, // If there are no cases left, just branch. if (Cases.empty()) { - BranchInst::Create(Default, CurBlock); + BranchInst::Create(Default, OrigBlock); SI->eraseFromParent(); // As all the cases have been replaced with a single branch, only keep // one entry in the PHI nodes. @@ -519,11 +593,6 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, } } - unsigned NrOfDefaults = (SI->getDefaultDest() == Default) ? 1 : 0; - for (const auto &Case : SI->cases()) - if (Case.getCaseSuccessor() == Default) - NrOfDefaults++; - // Create a new, empty default block so that the new hierarchy of // if-then statements go to this and the PHI nodes are happy. BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); @@ -536,14 +605,14 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, // If there are entries in any PHI nodes for the default edge, make sure // to update them as well. - fixPhis(Default, OrigBlock, NewDefault, NrOfDefaults); + fixPhis(Default, OrigBlock, NewDefault); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); // We are now done with the switch instruction, delete it. BasicBlock *OldDefault = SI->getDefaultDest(); - CurBlock->getInstList().erase(SI); + OrigBlock->getInstList().erase(SI); // If the Default block has no more predecessors just add it to DeleteList. if (pred_begin(OldDefault) == pred_end(OldDefault)) |