diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 197 |
1 files changed, 10 insertions, 187 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 65aa4b6ae69..32bf11ac5c5 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9627,40 +9627,17 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, return Eq; } -const SCEVPredicate *ScalarEvolution::getWrapPredicate( - const SCEVAddRecExpr *AR, - SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { - FoldingSetNodeID ID; - // Unique this node based on the arguments - ID.AddInteger(SCEVPredicate::P_Wrap); - ID.AddPointer(AR); - ID.AddInteger(AddedFlags); - void *IP = nullptr; - if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) - return S; - auto *OF = new (SCEVAllocator) - SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); - UniquePreds.InsertNode(OF, IP); - return OF; -} - namespace { - class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { public: - // Rewrites Scev in the context of a loop L and the predicate A. - // If Assume is true, rewrite is free to add further predicates to A - // such that the result will be an AddRecExpr. - static const SCEV *rewrite(const SCEV *Scev, const Loop *L, - ScalarEvolution &SE, SCEVUnionPredicate &A, - bool Assume) { - SCEVPredicateRewriter Rewriter(L, SE, A, Assume); + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + SCEVUnionPredicate &A) { + SCEVPredicateRewriter Rewriter(SE, A); return Rewriter.visit(Scev); } - SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, - SCEVUnionPredicate &P, bool Assume) - : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} + SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) + : SCEVRewriteVisitor(SE), P(P) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { auto ExprPreds = P.getPredicatesForExpr(Expr); @@ -9672,67 +9649,14 @@ public: return Expr; } - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - const SCEVAddRecExpr *AR = dyn_cast<const SCEVAddRecExpr>(Operand); - if (AR && AR->getLoop() == L && AR->isAffine()) { - // This couldn't be folded because the operand didn't have the nuw - // flag. Add the nusw flag as an assumption that we could make. - const SCEV *Step = AR->getStepRecurrence(SE); - Type *Ty = Expr->getType(); - if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) - return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), - SE.getSignExtendExpr(Step, Ty), L, - AR->getNoWrapFlags()); - } - return SE.getZeroExtendExpr(Operand, Expr->getType()); - } - - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - const SCEVAddRecExpr *AR = dyn_cast<const SCEVAddRecExpr>(Operand); - if (AR && AR->getLoop() == L && AR->isAffine()) { - // This couldn't be folded because the operand didn't have the nsw - // flag. Add the nssw flag as an assumption that we could make. - const SCEV *Step = AR->getStepRecurrence(SE); - Type *Ty = Expr->getType(); - if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) - return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), - SE.getSignExtendExpr(Step, Ty), L, - AR->getNoWrapFlags()); - } - return SE.getSignExtendExpr(Operand, Expr->getType()); - } - private: - bool addOverflowAssumption(const SCEVAddRecExpr *AR, - SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { - auto *A = SE.getWrapPredicate(AR, AddedFlags); - if (!Assume) { - // Check if we've already made this assumption. - if (P.implies(A)) - return true; - return false; - } - P.add(A); - return true; - } - SCEVUnionPredicate &P; - const Loop *L; - bool Assume; }; } // end anonymous namespace const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, - const Loop *L, SCEVUnionPredicate &Preds) { - return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, false); -} - -const SCEV *ScalarEvolution::convertSCEVToAddRecWithPredicates( - const SCEV *Scev, const Loop *L, SCEVUnionPredicate &Preds) { - return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, true); + return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); } /// SCEV predicates @@ -9762,59 +9686,6 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; } -SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, - const SCEVAddRecExpr *AR, - IncrementWrapFlags Flags) - : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} - -const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } - -bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { - const auto *Op = dyn_cast<SCEVWrapPredicate>(N); - - return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; -} - -bool SCEVWrapPredicate::isAlwaysTrue() const { - SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); - IncrementWrapFlags IFlags = Flags; - - if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) - IFlags = clearFlags(IFlags, IncrementNSSW); - - return IFlags == IncrementAnyWrap; -} - -void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { - OS.indent(Depth) << *getExpr() << " Added Flags: "; - if (SCEVWrapPredicate::IncrementNUSW & getFlags()) - OS << "<nusw>"; - if (SCEVWrapPredicate::IncrementNSSW & getFlags()) - OS << "<nssw>"; - OS << "\n"; -} - -SCEVWrapPredicate::IncrementWrapFlags -SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, - ScalarEvolution &SE) { - IncrementWrapFlags ImpliedFlags = IncrementAnyWrap; - SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); - - // We can safely transfer the NSW flag as NSSW. - if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) - ImpliedFlags = IncrementNSSW; - - if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { - // If the increment is positive, the SCEV NUW flag will also imply the - // WrapPredicate NUSW flag. - if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) - if (Step->getValue()->getValue().isNonNegative()) - ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); - } - - return ImpliedFlags; -} - /// Union predicates don't get cached so create a dummy set ID for it. SCEVUnionPredicate::SCEVUnionPredicate() : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} @@ -9871,9 +9742,8 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) { Preds.push_back(N); } -PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, - Loop &L) - : SE(SE), L(L), Generation(0) {} +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE) + : SE(SE), Generation(0) {} const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { const SCEV *Expr = SE.getSCEV(V); @@ -9888,7 +9758,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { if (Entry.second) Expr = Entry.second; - const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds); Entry = {Generation, NewSCEV}; return NewSCEV; @@ -9910,54 +9780,7 @@ void PredicatedScalarEvolution::updateGeneration() { if (++Generation == 0) { for (auto &II : RewriteMap) { const SCEV *Rewritten = II.second.second; - II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)}; } } } - -void PredicatedScalarEvolution::setNoOverflow( - Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { - const SCEV *Expr = getSCEV(V); - const auto *AR = cast<SCEVAddRecExpr>(Expr); - - auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); - - // Clear the statically implied flags. - Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); - addPredicate(*SE.getWrapPredicate(AR, Flags)); - - auto II = FlagsMap.insert({V, Flags}); - if (!II.second) - II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); -} - -bool PredicatedScalarEvolution::hasNoOverflow( - Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { - const SCEV *Expr = getSCEV(V); - const auto *AR = cast<SCEVAddRecExpr>(Expr); - - Flags = SCEVWrapPredicate::clearFlags( - Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); - - auto II = FlagsMap.find(V); - - if (II != FlagsMap.end()) - Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); - - return Flags == SCEVWrapPredicate::IncrementAnyWrap; -} - -const SCEV *PredicatedScalarEvolution::getAsAddRec(Value *V) { - const SCEV *Expr = this->getSCEV(V); - const SCEV *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); - updateGeneration(); - RewriteMap[SE.getSCEV(V)] = {Generation, New}; - return New; -} - -PredicatedScalarEvolution:: -PredicatedScalarEvolution(const PredicatedScalarEvolution &Init) : - RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds) { - for (auto I = Init.FlagsMap.begin(), E = Init.FlagsMap.end(); I != E; ++I) - FlagsMap.insert(*I); -} |