summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp197
1 files changed, 187 insertions, 10 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 32bf11ac5c5..65aa4b6ae69 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9627,17 +9627,40 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
return Eq;
}
+const SCEVPredicate *ScalarEvolution::getWrapPredicate(
+ const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ FoldingSetNodeID ID;
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Wrap);
+ ID.AddPointer(AR);
+ ID.AddInteger(AddedFlags);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ auto *OF = new (SCEVAllocator)
+ SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
+ UniquePreds.InsertNode(OF, IP);
+ return OF;
+}
+
namespace {
+
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:
- static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
- SCEVUnionPredicate &A) {
- SCEVPredicateRewriter Rewriter(SE, A);
+ // Rewrites Scev in the context of a loop L and the predicate A.
+ // If Assume is true, rewrite is free to add further predicates to A
+ // such that the result will be an AddRecExpr.
+ static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ ScalarEvolution &SE, SCEVUnionPredicate &A,
+ bool Assume) {
+ SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
return Rewriter.visit(Scev);
}
- SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
- : SCEVRewriteVisitor(SE), P(P) {}
+ SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+ SCEVUnionPredicate &P, bool Assume)
+ : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
auto ExprPreds = P.getPredicatesForExpr(Expr);
@@ -9649,14 +9672,67 @@ public:
return Expr;
}
+ const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+ const SCEV *Operand = visit(Expr->getOperand());
+ const SCEVAddRecExpr *AR = dyn_cast<const SCEVAddRecExpr>(Operand);
+ if (AR && AR->getLoop() == L && AR->isAffine()) {
+ // This couldn't be folded because the operand didn't have the nuw
+ // flag. Add the nusw flag as an assumption that we could make.
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ Type *Ty = Expr->getType();
+ if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
+ return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
+ SE.getSignExtendExpr(Step, Ty), L,
+ AR->getNoWrapFlags());
+ }
+ return SE.getZeroExtendExpr(Operand, Expr->getType());
+ }
+
+ const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+ const SCEV *Operand = visit(Expr->getOperand());
+ const SCEVAddRecExpr *AR = dyn_cast<const SCEVAddRecExpr>(Operand);
+ if (AR && AR->getLoop() == L && AR->isAffine()) {
+ // This couldn't be folded because the operand didn't have the nsw
+ // flag. Add the nssw flag as an assumption that we could make.
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ Type *Ty = Expr->getType();
+ if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
+ return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
+ SE.getSignExtendExpr(Step, Ty), L,
+ AR->getNoWrapFlags());
+ }
+ return SE.getSignExtendExpr(Operand, Expr->getType());
+ }
+
private:
+ bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ auto *A = SE.getWrapPredicate(AR, AddedFlags);
+ if (!Assume) {
+ // Check if we've already made this assumption.
+ if (P.implies(A))
+ return true;
+ return false;
+ }
+ P.add(A);
+ return true;
+ }
+
SCEVUnionPredicate &P;
+ const Loop *L;
+ bool Assume;
};
} // end anonymous namespace
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+ const Loop *L,
SCEVUnionPredicate &Preds) {
- return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+ return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, false);
+}
+
+const SCEV *ScalarEvolution::convertSCEVToAddRecWithPredicates(
+ const SCEV *Scev, const Loop *L, SCEVUnionPredicate &Preds) {
+ return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, true);
}
/// SCEV predicates
@@ -9686,6 +9762,59 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
}
+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVAddRecExpr *AR,
+ IncrementWrapFlags Flags)
+ : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+ return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+ SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+ IncrementWrapFlags IFlags = Flags;
+
+ if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+ IFlags = clearFlags(IFlags, IncrementNSSW);
+
+ return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << *getExpr() << " Added Flags: ";
+ if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+ OS << "<nusw>";
+ if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+ OS << "<nssw>";
+ OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+ ScalarEvolution &SE) {
+ IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+ SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+ // We can safely transfer the NSW flag as NSSW.
+ if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+ ImpliedFlags = IncrementNSSW;
+
+ if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+ // If the increment is positive, the SCEV NUW flag will also imply the
+ // WrapPredicate NUSW flag.
+ if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+ if (Step->getValue()->getValue().isNonNegative())
+ ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+ }
+
+ return ImpliedFlags;
+}
+
/// Union predicates don't get cached so create a dummy set ID for it.
SCEVUnionPredicate::SCEVUnionPredicate()
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
@@ -9742,8 +9871,9 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
Preds.push_back(N);
}
-PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
- : SE(SE), Generation(0) {}
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
+ Loop &L)
+ : SE(SE), L(L), Generation(0) {}
const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
const SCEV *Expr = SE.getSCEV(V);
@@ -9758,7 +9888,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
if (Entry.second)
Expr = Entry.second;
- const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+ const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
Entry = {Generation, NewSCEV};
return NewSCEV;
@@ -9780,7 +9910,54 @@ void PredicatedScalarEvolution::updateGeneration() {
if (++Generation == 0) {
for (auto &II : RewriteMap) {
const SCEV *Rewritten = II.second.second;
- II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+ II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
}
}
}
+
+void PredicatedScalarEvolution::setNoOverflow(
+ Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+ const SCEV *Expr = getSCEV(V);
+ const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+ auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
+
+ // Clear the statically implied flags.
+ Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
+ addPredicate(*SE.getWrapPredicate(AR, Flags));
+
+ auto II = FlagsMap.insert({V, Flags});
+ if (!II.second)
+ II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
+}
+
+bool PredicatedScalarEvolution::hasNoOverflow(
+ Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
+ const SCEV *Expr = getSCEV(V);
+ const auto *AR = cast<SCEVAddRecExpr>(Expr);
+
+ Flags = SCEVWrapPredicate::clearFlags(
+ Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
+
+ auto II = FlagsMap.find(V);
+
+ if (II != FlagsMap.end())
+ Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
+
+ return Flags == SCEVWrapPredicate::IncrementAnyWrap;
+}
+
+const SCEV *PredicatedScalarEvolution::getAsAddRec(Value *V) {
+ const SCEV *Expr = this->getSCEV(V);
+ const SCEV *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+ updateGeneration();
+ RewriteMap[SE.getSCEV(V)] = {Generation, New};
+ return New;
+}
+
+PredicatedScalarEvolution::
+PredicatedScalarEvolution(const PredicatedScalarEvolution &Init) :
+ RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds) {
+ for (auto I = Init.FlagsMap.begin(), E = Init.FlagsMap.end(); I != E; ++I)
+ FlagsMap.insert(*I);
+}
OpenPOWER on IntegriCloud