diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp | 105 |
1 files changed, 84 insertions, 21 deletions
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 5a84d74683a..8c9bd7e72b3 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -253,8 +253,11 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, /// (splitting the exit block as necessary). It simplifies the branch within /// the loop to an unconditional branch but doesn't remove it entirely. Further /// cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); @@ -318,6 +321,16 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, } }); + // If we have scalar evolutions, we need to invalidate them including this + // loop and the loop containing the exit block. + if (SE) { + if (Loop *ExitL = LI.getLoopFor(LoopExitBB)) + SE->forgetLoop(ExitL); + else + // Forget the entire nest as this exits the entire nest. + SE->forgetTopmostLoop(&L); + } + // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional // branch on LoopCond. @@ -420,8 +433,11 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, /// switch will not be revisited. If after unswitching there is only a single /// in-loop successor, the switch is further simplified to an unconditional /// branch. Still more cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); Value *LoopCond = SI.getCondition(); @@ -448,18 +464,33 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n"); + // We may need to invalidate SCEVs for the outermost loop reached by any of + // the exits. + Loop *OuterL = &L; + SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; ExitCases.reserve(ExitCaseIndices.size()); // We walk the case indices backwards so that we remove the last case first // and don't disrupt the earlier indices. for (unsigned Index : reverse(ExitCaseIndices)) { auto CaseI = SI.case_begin() + Index; + // Compute the outer loop from this exit. + Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor()); + if (!ExitL || ExitL->contains(OuterL)) + OuterL = ExitL; // Save the value of this case. ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); // Delete the unswitched cases. SI.removeCase(CaseI); } + if (SE) { + if (OuterL) + SE->forgetLoop(OuterL); + else + SE->forgetTopmostLoop(&L); + } + // Check if after this all of the remaining cases point at the same // successor. BasicBlock *CommonSuccBB = nullptr; @@ -617,8 +648,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, /// /// The return value indicates whether anything was unswitched (and therefore /// changed). +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { bool Changed = false; // If loop header has only one reachable successor we should keep looking for @@ -652,7 +686,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, if (isa<Constant>(SI->getCondition())) return Changed; - if (!unswitchTrivialSwitch(L, *SI, DT, LI)) + if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE)) // Couldn't unswitch this one so we're done. return Changed; @@ -684,7 +718,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, // Found a trivial condition candidate: non-foldable conditional branch. If // we fail to unswitch this, we can't do anything else that is trivial. - if (!unswitchTrivialBranch(L, *BI, DT, LI)) + if (!unswitchTrivialBranch(L, *BI, DT, LI, SE)) return Changed; // Mark that we managed to unswitch something. @@ -1622,7 +1656,8 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { static bool unswitchNontrivialInvariants( Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) { + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast<BranchInst>(&TI); SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI); @@ -1705,6 +1740,16 @@ static bool unswitchNontrivialInvariants( OuterExitL = NewOuterExitL; } + // At this point, we're definitely going to unswitch something so invalidate + // any cached information in ScalarEvolution for the outer most loop + // containing an exit block and all nested loops. + if (SE) { + if (OuterExitL) + SE->forgetLoop(OuterExitL); + else + SE->forgetTopmostLoop(&L); + } + // If the edge from this terminator to a successor dominates that successor, // store a map from each block in its dominator subtree to it. This lets us // tell when cloning for a particular successor if a block is dominated by @@ -1968,10 +2013,11 @@ computeDomSubtreeCost(DomTreeNode &N, return Cost; } -static bool unswitchBestCondition( - Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - TargetTransformInfo &TTI, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) { +static bool +unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { // Collect all invariant conditions within this loop (as opposed to an inner // loop which would be handled when visiting that inner loop). SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4> @@ -2164,7 +2210,7 @@ static bool unswitchBestCondition( << BestUnswitchCost << ") terminator: " << *BestUnswitchTI << "\n"); return unswitchNontrivialInvariants( - L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB); + L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE); } /// Unswitch control flow predicated on loop invariant conditions. @@ -2173,10 +2219,25 @@ static bool unswitchBestCondition( /// require duplicating any part of the loop) out of the loop body. It then /// looks at other loop invariant control flows and tries to unswitch those as /// well by cloning the loop if the result is small enough. -static bool -unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - TargetTransformInfo &TTI, bool NonTrivial, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) { +/// +/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also +/// updated based on the unswitch. +/// +/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is +/// true, we will attempt to do non-trivial unswitching as well as trivial +/// unswitching. +/// +/// The `UnswitchCB` callback provided will be run after unswitching is +/// complete, with the first parameter set to `true` if the provided loop +/// remains a loop, and a list of new sibling loops created. +/// +/// If `SE` is non-null, we will update that analysis based on the unswitching +/// done. +static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + bool NonTrivial, + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); bool Changed = false; @@ -2186,7 +2247,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, return false; // Try trivial unswitch first before loop over other basic blocks in the loop. - if (unswitchAllTrivialConditions(L, DT, LI)) { + if (unswitchAllTrivialConditions(L, DT, LI, SE)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. UnswitchCB(/*CurrentLoopValid*/ true, {}); @@ -2207,7 +2268,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. - if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB)) + if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE)) return true; // No other opportunities to unswitch. @@ -2241,8 +2302,8 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, U.markLoopAsDeleted(L, LoopName); }; - if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, - UnswitchCB)) + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB, + &AR.SE)) return PreservedAnalyses::all(); // Historically this pass has had issues with the dominator tree so verify it @@ -2290,6 +2351,9 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + auto *SE = SEWP ? &SEWP->getSE() : nullptr; + auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, ArrayRef<Loop *> NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. @@ -2305,8 +2369,7 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { LPM.markLoopAsDeleted(*L); }; - bool Changed = - unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB); + bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE); // If anything was unswitched, also clear any cached information about this // loop. |

