diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 55 | 
1 files changed, 44 insertions, 11 deletions
| diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 99596df6051..2c582ca90f6 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -115,6 +115,11 @@ static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",  static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch",                                                   cl::Hidden, cl::init(true)); +static cl::opt<bool> AllowNarrowLatchCondition( +    "irce-allow-narrow-latch", cl::Hidden, cl::init(false), +    cl::desc("If set to true, IRCE may eliminate wide range checks in loops " +             "with narrow latch condition.")); +  static const char *ClonedLoopTag = "irce.loop.clone";  #define DEBUG_TYPE "irce" @@ -1044,11 +1049,23 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,    return Result;  } +/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return +/// signed or unsigned extension of \p S to type \p Ty. +static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, +                                bool Signed) { +  return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty); +} +  Optional<LoopConstrainer::SubRanges>  LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {    IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); -  if (Range.getType() != Ty) +  auto *RTy = cast<IntegerType>(Range.getType()); + +  // We only support wide range checks and narrow latches. +  if (!AllowNarrowLatchCondition && RTy != Ty) +    return None; +  if (RTy->getBitWidth() < Ty->getBitWidth())      return None;    LoopConstrainer::SubRanges Result; @@ -1056,8 +1073,10 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {    // I think we can be more aggressive here and make this nuw / nsw if the    // addition that feeds into the icmp for the latch's terminating branch is nuw    // / nsw.  In any case, a wrapping 2's complement addition is safe. -  const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); -  const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); +  const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart), +                                   RTy, SE, IsSignedPredicate); +  const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy, +                                 SE, IsSignedPredicate);    bool Increasing = MainLoopStructure.IndVarIncreasing; @@ -1067,7 +1086,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {    const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr; -  const SCEV *One = SE.getOne(Ty); +  const SCEV *One = SE.getOne(RTy);    if (Increasing) {      Smallest = Start;      Greatest = End; @@ -1256,6 +1275,13 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(    bool IsSignedPredicate = LS.IsSignedPredicate;    IRBuilder<> B(PreheaderJump); +  auto *RangeTy = Range.getBegin()->getType(); +  auto NoopOrExt = [&](Value *V) { +    if (V->getType() == RangeTy) +      return V; +    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) +                             : B.CreateZExt(V, RangeTy, "wide." + V->getName()); +  };    // EnterLoopCond - is it okay to start executing this `LS'?    Value *EnterLoopCond = nullptr; @@ -1263,9 +1289,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(        Increasing            ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)            : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); -  Value *IndVarStart = LS.IndVarStart; -  Value *IndVarBase = LS.IndVarBase; -  Value *LoopExitAt = LS.LoopExitAt; +  Value *IndVarStart = NoopOrExt(LS.IndVarStart);    EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);    B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); @@ -1273,6 +1297,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(    LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);    B.SetInsertPoint(LS.LatchBr); +  Value *IndVarBase = NoopOrExt(LS.IndVarBase);    Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);    Value *CondForBranch = LS.LatchBrExitIdx == 1 @@ -1286,6 +1311,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(    // IterationsLeft - are there any more iterations left, given the original    // upper bound on the induction variable?  If not, we branch to the "real"    // exit. +  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);    Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);    B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); @@ -1394,7 +1420,7 @@ bool LoopConstrainer::run() {    SubRanges SR = MaybeSR.getValue();    bool Increasing = MainLoopStructure.IndVarIncreasing;    IntegerType *IVTy = -      cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); +      cast<IntegerType>(Range.getBegin()->getType());    SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");    Instruction *InsertPt = OriginalPreheader->getTerminator(); @@ -1557,6 +1583,12 @@ Optional<InductiveRangeCheck::Range>  InductiveRangeCheck::computeSafeIterationSpace(      ScalarEvolution &SE, const SCEVAddRecExpr *IndVar,      bool IsLatchSigned) const { +  // We can deal when types of latch check and range checks don't match in case +  // if latch check is more narrow. +  auto *IVType = cast<IntegerType>(IndVar->getType()); +  auto *RCType = cast<IntegerType>(getBegin()->getType()); +  if (IVType->getBitWidth() > RCType->getBitWidth()) +    return None;    // IndVar is of the form "A + B * I" (where "I" is the canonical induction    // variable, that may or may not exist as a real llvm::Value in the loop) and    // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1580,8 +1612,9 @@ InductiveRangeCheck::computeSafeIterationSpace(    if (!IndVar->isAffine())      return None; -  const SCEV *A = IndVar->getStart(); -  const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); +  const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned); +  const SCEVConstant *B = dyn_cast<SCEVConstant>( +      NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned));    if (!B)      return None;    assert(!B->isZero() && "Recurrence with zero step?"); @@ -1592,7 +1625,7 @@ InductiveRangeCheck::computeSafeIterationSpace(      return None;    assert(!D->getValue()->isZero() && "Recurrence with zero step?"); -  unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); +  unsigned BitWidth = RCType->getBitWidth();    const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));    // Subtract Y from X so that it does not go through border of the IV | 

