diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 80 |
1 files changed, 49 insertions, 31 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index a9f404f3dff..f7d4e5191ee 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6293,6 +6293,7 @@ ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { BackedgeTakenInfo Result = computeBackedgeTakenCount(L, /*AllowPredicates=*/true); + addToLoopUseLists(Result, L); return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); } @@ -6368,6 +6369,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // recusive call to getBackedgeTakenInfo (on a different // loop), which would invalidate the iterator computed // earlier. + addToLoopUseLists(Result, L); return BackedgeTakenCounts.find(L)->second = std::move(Result); } @@ -6405,8 +6407,14 @@ void ScalarEvolution::forgetLoop(const Loop *L) { auto LoopUsersItr = LoopUsers.find(CurrL); if (LoopUsersItr != LoopUsers.end()) { - for (auto *S : LoopUsersItr->second) - forgetMemoizedResults(S); + for (auto LoopOrSCEV : LoopUsersItr->second) { + if (auto *S = LoopOrSCEV.dyn_cast<const SCEV *>()) + forgetMemoizedResults(S); + else { + BackedgeTakenCounts.erase(LoopOrSCEV.get<const Loop *>()); + PredicatedBackedgeTakenCounts.erase(LoopOrSCEV.get<const Loop *>()); + } + } LoopUsers.erase(LoopUsersItr); } @@ -6551,6 +6559,34 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, return false; } +static void findUsedLoopsInSCEVExpr(const SCEV *S, + SmallPtrSetImpl<const Loop *> &Result) { + struct FindUsedLoops { + SmallPtrSetImpl<const Loop *> &LoopsUsed; + FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed) + : LoopsUsed(LoopsUsed) {} + bool follow(const SCEV *S) { + if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) + LoopsUsed.insert(AR->getLoop()); + return true; + } + + bool isDone() const { return false; } + }; + FindUsedLoops F(Result); + SCEVTraversal<FindUsedLoops>(F).visitAll(S); +} + +void ScalarEvolution::BackedgeTakenInfo::findUsedLoops( + ScalarEvolution &SE, SmallPtrSetImpl<const Loop *> &Result) const { + if (auto *S = getMax()) + if (S != SE.getCouldNotCompute()) + findUsedLoopsInSCEVExpr(S, Result); + for (auto &ENT : ExitNotTaken) + if (ENT.ExactNotTaken != SE.getCouldNotCompute()) + findUsedLoopsInSCEVExpr(ENT.ExactNotTaken, Result); +} + ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) { assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || @@ -11034,21 +11070,6 @@ ScalarEvolution::forgetMemoizedResults(const SCEV *S, bool EraseExitLimit) { ++I; } - auto RemoveSCEVFromBackedgeMap = - [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { - for (auto I = Map.begin(), E = Map.end(); I != E;) { - BackedgeTakenInfo &BEInfo = I->second; - if (BEInfo.hasOperand(S, this)) { - BEInfo.clear(); - Map.erase(I++); - } else - ++I; - } - }; - - RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); - RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); - // TODO: There is a suspicion that we only need to do it when there is a // SCEVUnknown somewhere inside S. Need to check this. if (EraseExitLimit) @@ -11058,22 +11079,19 @@ ScalarEvolution::forgetMemoizedResults(const SCEV *S, bool EraseExitLimit) { } void ScalarEvolution::addToLoopUseLists(const SCEV *S) { - struct FindUsedLoops { - SmallPtrSet<const Loop *, 8> LoopsUsed; - bool follow(const SCEV *S) { - if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) - LoopsUsed.insert(AR->getLoop()); - return true; - } - - bool isDone() const { return false; } - }; + SmallPtrSet<const Loop *, 8> LoopsUsed; + findUsedLoopsInSCEVExpr(S, LoopsUsed); + for (auto *L : LoopsUsed) + LoopUsers[L].push_back({S}); +} - FindUsedLoops F; - SCEVTraversal<FindUsedLoops>(F).visitAll(S); +void ScalarEvolution::addToLoopUseLists( + const ScalarEvolution::BackedgeTakenInfo &BTI, const Loop *L) { + SmallPtrSet<const Loop *, 8> LoopsUsed; + BTI.findUsedLoops(*this, LoopsUsed); - for (auto *L : F.LoopsUsed) - LoopUsers[L].push_back(S); + for (auto *UsedL : LoopsUsed) + LoopUsers[UsedL].push_back({L}); } void ScalarEvolution::verify() const { |