diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopPredication.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LoopPredication.cpp | 258 |
1 files changed, 220 insertions, 38 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 9b12ba18044..84577dd182a 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -34,6 +34,120 @@ // else // deoptimize // +// It's tempting to rely on SCEV here, but it has proven to be problematic. +// Generally the facts SCEV provides about the increment step of add +// recurrences are true if the backedge of the loop is taken, which implicitly +// assumes that the guard doesn't fail. Using these facts to optimize the +// guard results in a circular logic where the guard is optimized under the +// assumption that it never fails. +// +// For example, in the loop below the induction variable will be marked as nuw +// basing on the guard. Basing on nuw the guard predicate will be considered +// monotonic. Given a monotonic condition it's tempting to replace the induction +// variable in the condition with its value on the last iteration. But this +// transformation is not correct, e.g. e = 4, b = 5 breaks the loop. +// +// for (int i = b; i != e; i++) +// guard(i u< len) +// +// One of the ways to reason about this problem is to use an inductive proof +// approach. Given the loop: +// +// if (B(Start)) { +// do { +// I = PHI(Start, I.INC) +// I.INC = I + Step +// guard(G(I)); +// } while (B(I.INC)); +// } +// +// where B(x) and G(x) are predicates that map integers to booleans, we want a +// loop invariant expression M such the following program has the same semantics +// as the above: +// +// if (B(Start)) { +// do { +// I = PHI(Start, I.INC) +// I.INC = I + Step +// guard(G(Start) && M); +// } while (B(I.INC)); +// } +// +// One solution for M is M = forall X . (G(X) && B(X + Step)) => G(X + Step) +// +// Informal proof that the transformation above is correct: +// +// By the definition of guards we can rewrite the guard condition to: +// G(I) && G(Start) && M +// +// Let's prove that for each iteration of the loop: +// G(Start) && M => G(I) +// And the condition above can be simplified to G(Start) && M. +// +// Induction base. +// G(Start) && M => G(Start) +// +// Induction step. Assuming G(Start) && M => G(I) on the subsequent +// iteration: +// +// B(I + Step) is true because it's the backedge condition. +// G(I) is true because the backedge is guarded by this condition. +// +// So M = forall X . (G(X) && B(X + Step)) => G(X + Step) implies +// G(I + Step). +// +// Note that we can use anything stronger than M, i.e. any condition which +// implies M. +// +// For now the transformation is limited to the following case: +// * The loop has a single latch with either ult or slt icmp condition. +// * The step of the IV used in the latch condition is 1. +// * The IV of the latch condition is the same as the post increment IV of the +// guard condition. +// * The guard condition is ult. +// +// In this case the latch is of the from: +// ++i u< latchLimit or ++i s< latchLimit +// and the guard is of the form: +// i u< guardLimit +// +// For the unsigned latch comparison case M is: +// forall X . X u< guardLimit && (X + 1) u< latchLimit => +// (X + 1) u< guardLimit +// +// This is true if latchLimit u<= guardLimit since then +// (X + 1) u< latchLimit u<= guardLimit == (X + 1) u< guardLimit. +// +// So the widened condition is: +// i.start u< guardLimit && latchLimit u<= guardLimit +// +// For the signed latch comparison case M is: +// forall X . X u< guardLimit && (X + 1) s< latchLimit => +// (X + 1) u< guardLimit +// +// The only way the antecedent can be true and the consequent can be false is +// if +// X == guardLimit - 1 +// (and guardLimit is non-zero, but we won't use this latter fact). +// If X == guardLimit - 1 then the second half of the antecedent is +// guardLimit s< latchLimit +// and its negation is +// latchLimit s<= guardLimit. +// +// In other words, if latchLimit s<= guardLimit then: +// (the ranges below are written in ConstantRange notation, where [A, B) is the +// set for (I = A; I != B; I++ /*maywrap*/) yield(I);) +// +// forall X . X u< guardLimit && (X + 1) s< latchLimit => (X + 1) u< guardLimit +// == forall X . X u< guardLimit && (X + 1) s< guardLimit => (X + 1) u< guardLimit +// == forall X . X in [0, guardLimit) && (X + 1) in [INT_MIN, guardLimit) => (X + 1) in [0, guardLimit) +// == forall X . X in [0, guardLimit) && X in [INT_MAX, guardLimit-1) => X in [-1, guardLimit-1) +// == forall X . X in [0, guardLimit-1) => X in [-1, guardLimit-1) +// == true +// +// So the widened condition is: +// i.start u< guardLimit && latchLimit s<= guardLimit +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" @@ -75,8 +189,16 @@ class LoopPredication { Loop *L; const DataLayout *DL; BasicBlock *Preheader; + LoopICmp LatchCheck; - Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); + Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) { + return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0), + ICI->getOperand(1)); + } + Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, + Value *RHS); + + Optional<LoopICmp> parseLoopLatchICmp(); Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, @@ -135,11 +257,8 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, } Optional<LoopPredication::LoopICmp> -LoopPredication::parseLoopICmp(ICmpInst *ICI) { - ICmpInst::Predicate Pred = ICI->getPredicate(); - - Value *LHS = ICI->getOperand(0); - Value *RHS = ICI->getOperand(1); +LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, + Value *RHS) { const SCEV *LHSS = SE->getSCEV(LHS); if (isa<SCEVCouldNotCompute>(LHSS)) return None; @@ -165,6 +284,8 @@ Value *LoopPredication::expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Instruction *InsertAt) { + // TODO: we can check isLoopEntryGuardedByCond before emitting the check + Type *Ty = LHS->getType(); assert(Ty == RHS->getType() && "expandCheck operands have different types?"); Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); @@ -181,51 +302,54 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); DEBUG(ICI->dump()); + // parseLoopStructure guarantees that the latch condition is: + // ++i u< latchLimit or ++i s< latchLimit + // We are looking for the range checks of the form: + // i u< guardLimit auto RangeCheck = parseLoopICmp(ICI); if (!RangeCheck) { DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } - - ICmpInst::Predicate Pred = RangeCheck->Pred; - const SCEVAddRecExpr *IndexAR = RangeCheck->IV; - const SCEV *RHSS = RangeCheck->Limit; - - auto CanExpand = [this](const SCEV *S) { - return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); - }; - if (!CanExpand(RHSS)) + if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { + DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred + << ")!\n"); return None; - - DEBUG(dbgs() << "IndexAR: "); - DEBUG(IndexAR->dump()); - - bool IsIncreasing = false; - if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing)) - return None; - - // If the predicate is increasing the condition can change from false to true - // as the loop progresses, in this case take the value on the first iteration - // for the widened check. Otherwise the condition can change from true to - // false as the loop progresses, so take the value on the last iteration. - const SCEV *NewLHSS = IsIncreasing - ? IndexAR->getStart() - : SE->getSCEVAtScope(IndexAR, L->getParentLoop()); - if (NewLHSS == IndexAR) { - DEBUG(dbgs() << "Can't compute NewLHSS!\n"); + } + auto *RangeCheckIV = RangeCheck->IV; + auto *PostIncRangeCheckIV = RangeCheckIV->getPostIncExpr(*SE); + if (LatchCheck.IV != PostIncRangeCheckIV) { + DEBUG(dbgs() << "Post increment range check IV (" << *PostIncRangeCheckIV + << ") is not the same as latch IV (" << *LatchCheck.IV + << ")!\n"); return None; } + assert(RangeCheckIV->getStepRecurrence(*SE)->isOne() && "must be one"); + const SCEV *Start = RangeCheckIV->getStart(); - DEBUG(dbgs() << "NewLHSS: "); - DEBUG(NewLHSS->dump()); + // Generate the widened condition. See the file header comment for reasoning. + // If the latch condition is unsigned: + // i.start u< guardLimit && latchLimit u<= guardLimit + // If the latch condition is signed: + // i.start u< guardLimit && latchLimit s<= guardLimit - if (!CanExpand(NewLHSS)) - return None; + auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred) + ? ICmpInst::ICMP_SLE + : ICmpInst::ICMP_ULE; - DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n"); + auto CanExpand = [this](const SCEV *S) { + return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); + }; + if (!CanExpand(Start) || !CanExpand(LatchCheck.Limit) || + !CanExpand(RangeCheck->Limit)) + return None; Instruction *InsertAt = Preheader->getTerminator(); - return expandCheck(Expander, Builder, Pred, NewLHSS, RHSS, InsertAt); + auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck->Pred, + Start, RangeCheck->Limit, InsertAt); + auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, + LatchCheck.Limit, RangeCheck->Limit, InsertAt); + return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, @@ -288,6 +412,59 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, return true; } +Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { + using namespace PatternMatch; + + BasicBlock *LoopLatch = L->getLoopLatch(); + if (!LoopLatch) { + DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); + return None; + } + + ICmpInst::Predicate Pred; + Value *LHS, *RHS; + BasicBlock *TrueDest, *FalseDest; + + if (!match(LoopLatch->getTerminator(), + m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest, + FalseDest))) { + DEBUG(dbgs() << "Failed to match the latch terminator!\n"); + return None; + } + assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) && + "One of the latch's destinations must be the header"); + if (TrueDest != L->getHeader()) + Pred = ICmpInst::getInversePredicate(Pred); + + auto Result = parseLoopICmp(Pred, LHS, RHS); + if (!Result) { + DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + return None; + } + + if (Result->Pred != ICmpInst::ICMP_ULT && + Result->Pred != ICmpInst::ICMP_SLT) { + DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred + << ")!\n"); + return None; + } + + // Check affine first, so if it's not we don't try to compute the step + // recurrence. + if (!Result->IV->isAffine()) { + DEBUG(dbgs() << "The induction variable is not affine!\n"); + return None; + } + + auto *Step = Result->IV->getStepRecurrence(*SE); + if (!Step->isOne()) { + DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); + return None; + } + + return Result; +} + bool LoopPredication::runOnLoop(Loop *Loop) { L = Loop; @@ -308,6 +485,11 @@ bool LoopPredication::runOnLoop(Loop *Loop) { if (!Preheader) return false; + auto LatchCheckOpt = parseLoopLatchICmp(); + if (!LatchCheckOpt) + return false; + LatchCheck = *LatchCheckOpt; + // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; |

