summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/Scalar/LoopPredication.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPredication.cpp258
1 files changed, 220 insertions, 38 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 9b12ba18044..84577dd182a 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -34,6 +34,120 @@
// else
// deoptimize
//
+// It's tempting to rely on SCEV here, but it has proven to be problematic.
+// Generally the facts SCEV provides about the increment step of add
+// recurrences are true if the backedge of the loop is taken, which implicitly
+// assumes that the guard doesn't fail. Using these facts to optimize the
+// guard results in a circular logic where the guard is optimized under the
+// assumption that it never fails.
+//
+// For example, in the loop below the induction variable will be marked as nuw
+// basing on the guard. Basing on nuw the guard predicate will be considered
+// monotonic. Given a monotonic condition it's tempting to replace the induction
+// variable in the condition with its value on the last iteration. But this
+// transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
+//
+// for (int i = b; i != e; i++)
+// guard(i u< len)
+//
+// One of the ways to reason about this problem is to use an inductive proof
+// approach. Given the loop:
+//
+// if (B(Start)) {
+// do {
+// I = PHI(Start, I.INC)
+// I.INC = I + Step
+// guard(G(I));
+// } while (B(I.INC));
+// }
+//
+// where B(x) and G(x) are predicates that map integers to booleans, we want a
+// loop invariant expression M such the following program has the same semantics
+// as the above:
+//
+// if (B(Start)) {
+// do {
+// I = PHI(Start, I.INC)
+// I.INC = I + Step
+// guard(G(Start) && M);
+// } while (B(I.INC));
+// }
+//
+// One solution for M is M = forall X . (G(X) && B(X + Step)) => G(X + Step)
+//
+// Informal proof that the transformation above is correct:
+//
+// By the definition of guards we can rewrite the guard condition to:
+// G(I) && G(Start) && M
+//
+// Let's prove that for each iteration of the loop:
+// G(Start) && M => G(I)
+// And the condition above can be simplified to G(Start) && M.
+//
+// Induction base.
+// G(Start) && M => G(Start)
+//
+// Induction step. Assuming G(Start) && M => G(I) on the subsequent
+// iteration:
+//
+// B(I + Step) is true because it's the backedge condition.
+// G(I) is true because the backedge is guarded by this condition.
+//
+// So M = forall X . (G(X) && B(X + Step)) => G(X + Step) implies
+// G(I + Step).
+//
+// Note that we can use anything stronger than M, i.e. any condition which
+// implies M.
+//
+// For now the transformation is limited to the following case:
+// * The loop has a single latch with either ult or slt icmp condition.
+// * The step of the IV used in the latch condition is 1.
+// * The IV of the latch condition is the same as the post increment IV of the
+// guard condition.
+// * The guard condition is ult.
+//
+// In this case the latch is of the from:
+// ++i u< latchLimit or ++i s< latchLimit
+// and the guard is of the form:
+// i u< guardLimit
+//
+// For the unsigned latch comparison case M is:
+// forall X . X u< guardLimit && (X + 1) u< latchLimit =>
+// (X + 1) u< guardLimit
+//
+// This is true if latchLimit u<= guardLimit since then
+// (X + 1) u< latchLimit u<= guardLimit == (X + 1) u< guardLimit.
+//
+// So the widened condition is:
+// i.start u< guardLimit && latchLimit u<= guardLimit
+//
+// For the signed latch comparison case M is:
+// forall X . X u< guardLimit && (X + 1) s< latchLimit =>
+// (X + 1) u< guardLimit
+//
+// The only way the antecedent can be true and the consequent can be false is
+// if
+// X == guardLimit - 1
+// (and guardLimit is non-zero, but we won't use this latter fact).
+// If X == guardLimit - 1 then the second half of the antecedent is
+// guardLimit s< latchLimit
+// and its negation is
+// latchLimit s<= guardLimit.
+//
+// In other words, if latchLimit s<= guardLimit then:
+// (the ranges below are written in ConstantRange notation, where [A, B) is the
+// set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
+//
+// forall X . X u< guardLimit && (X + 1) s< latchLimit => (X + 1) u< guardLimit
+// == forall X . X u< guardLimit && (X + 1) s< guardLimit => (X + 1) u< guardLimit
+// == forall X . X in [0, guardLimit) && (X + 1) in [INT_MIN, guardLimit) => (X + 1) in [0, guardLimit)
+// == forall X . X in [0, guardLimit) && X in [INT_MAX, guardLimit-1) => X in [-1, guardLimit-1)
+// == forall X . X in [0, guardLimit-1) => X in [-1, guardLimit-1)
+// == true
+//
+// So the widened condition is:
+// i.start u< guardLimit && latchLimit s<= guardLimit
+//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LoopPredication.h"
@@ -75,8 +189,16 @@ class LoopPredication {
Loop *L;
const DataLayout *DL;
BasicBlock *Preheader;
+ LoopICmp LatchCheck;
- Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
+ Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) {
+ return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0),
+ ICI->getOperand(1));
+ }
+ Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
+ Value *RHS);
+
+ Optional<LoopICmp> parseLoopLatchICmp();
Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder,
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
@@ -135,11 +257,8 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
}
Optional<LoopPredication::LoopICmp>
-LoopPredication::parseLoopICmp(ICmpInst *ICI) {
- ICmpInst::Predicate Pred = ICI->getPredicate();
-
- Value *LHS = ICI->getOperand(0);
- Value *RHS = ICI->getOperand(1);
+LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
+ Value *RHS) {
const SCEV *LHSS = SE->getSCEV(LHS);
if (isa<SCEVCouldNotCompute>(LHSS))
return None;
@@ -165,6 +284,8 @@ Value *LoopPredication::expandCheck(SCEVExpander &Expander,
IRBuilder<> &Builder,
ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS, Instruction *InsertAt) {
+ // TODO: we can check isLoopEntryGuardedByCond before emitting the check
+
Type *Ty = LHS->getType();
assert(Ty == RHS->getType() && "expandCheck operands have different types?");
Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
@@ -181,51 +302,54 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
DEBUG(ICI->dump());
+ // parseLoopStructure guarantees that the latch condition is:
+ // ++i u< latchLimit or ++i s< latchLimit
+ // We are looking for the range checks of the form:
+ // i u< guardLimit
auto RangeCheck = parseLoopICmp(ICI);
if (!RangeCheck) {
DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
return None;
}
-
- ICmpInst::Predicate Pred = RangeCheck->Pred;
- const SCEVAddRecExpr *IndexAR = RangeCheck->IV;
- const SCEV *RHSS = RangeCheck->Limit;
-
- auto CanExpand = [this](const SCEV *S) {
- return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
- };
- if (!CanExpand(RHSS))
+ if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
+ DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred
+ << ")!\n");
return None;
-
- DEBUG(dbgs() << "IndexAR: ");
- DEBUG(IndexAR->dump());
-
- bool IsIncreasing = false;
- if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing))
- return None;
-
- // If the predicate is increasing the condition can change from false to true
- // as the loop progresses, in this case take the value on the first iteration
- // for the widened check. Otherwise the condition can change from true to
- // false as the loop progresses, so take the value on the last iteration.
- const SCEV *NewLHSS = IsIncreasing
- ? IndexAR->getStart()
- : SE->getSCEVAtScope(IndexAR, L->getParentLoop());
- if (NewLHSS == IndexAR) {
- DEBUG(dbgs() << "Can't compute NewLHSS!\n");
+ }
+ auto *RangeCheckIV = RangeCheck->IV;
+ auto *PostIncRangeCheckIV = RangeCheckIV->getPostIncExpr(*SE);
+ if (LatchCheck.IV != PostIncRangeCheckIV) {
+ DEBUG(dbgs() << "Post increment range check IV (" << *PostIncRangeCheckIV
+ << ") is not the same as latch IV (" << *LatchCheck.IV
+ << ")!\n");
return None;
}
+ assert(RangeCheckIV->getStepRecurrence(*SE)->isOne() && "must be one");
+ const SCEV *Start = RangeCheckIV->getStart();
- DEBUG(dbgs() << "NewLHSS: ");
- DEBUG(NewLHSS->dump());
+ // Generate the widened condition. See the file header comment for reasoning.
+ // If the latch condition is unsigned:
+ // i.start u< guardLimit && latchLimit u<= guardLimit
+ // If the latch condition is signed:
+ // i.start u< guardLimit && latchLimit s<= guardLimit
- if (!CanExpand(NewLHSS))
- return None;
+ auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred)
+ ? ICmpInst::ICMP_SLE
+ : ICmpInst::ICMP_ULE;
- DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n");
+ auto CanExpand = [this](const SCEV *S) {
+ return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
+ };
+ if (!CanExpand(Start) || !CanExpand(LatchCheck.Limit) ||
+ !CanExpand(RangeCheck->Limit))
+ return None;
Instruction *InsertAt = Preheader->getTerminator();
- return expandCheck(Expander, Builder, Pred, NewLHSS, RHSS, InsertAt);
+ auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck->Pred,
+ Start, RangeCheck->Limit, InsertAt);
+ auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred,
+ LatchCheck.Limit, RangeCheck->Limit, InsertAt);
+ return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}
bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
@@ -288,6 +412,59 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
return true;
}
+Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
+ using namespace PatternMatch;
+
+ BasicBlock *LoopLatch = L->getLoopLatch();
+ if (!LoopLatch) {
+ DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
+ return None;
+ }
+
+ ICmpInst::Predicate Pred;
+ Value *LHS, *RHS;
+ BasicBlock *TrueDest, *FalseDest;
+
+ if (!match(LoopLatch->getTerminator(),
+ m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest,
+ FalseDest))) {
+ DEBUG(dbgs() << "Failed to match the latch terminator!\n");
+ return None;
+ }
+ assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) &&
+ "One of the latch's destinations must be the header");
+ if (TrueDest != L->getHeader())
+ Pred = ICmpInst::getInversePredicate(Pred);
+
+ auto Result = parseLoopICmp(Pred, LHS, RHS);
+ if (!Result) {
+ DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
+ return None;
+ }
+
+ if (Result->Pred != ICmpInst::ICMP_ULT &&
+ Result->Pred != ICmpInst::ICMP_SLT) {
+ DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
+ << ")!\n");
+ return None;
+ }
+
+ // Check affine first, so if it's not we don't try to compute the step
+ // recurrence.
+ if (!Result->IV->isAffine()) {
+ DEBUG(dbgs() << "The induction variable is not affine!\n");
+ return None;
+ }
+
+ auto *Step = Result->IV->getStepRecurrence(*SE);
+ if (!Step->isOne()) {
+ DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
+ return None;
+ }
+
+ return Result;
+}
+
bool LoopPredication::runOnLoop(Loop *Loop) {
L = Loop;
@@ -308,6 +485,11 @@ bool LoopPredication::runOnLoop(Loop *Loop) {
if (!Preheader)
return false;
+ auto LatchCheckOpt = parseLoopLatchICmp();
+ if (!LatchCheckOpt)
+ return false;
+ LatchCheck = *LatchCheckOpt;
+
// Collect all the guards into a vector and process later, so as not
// to invalidate the instruction iterator.
SmallVector<IntrinsicInst *, 4> Guards;
OpenPOWER on IntegriCloud