diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 55 |
1 files changed, 44 insertions, 11 deletions
diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 99596df6051..2c582ca90f6 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -115,6 +115,11 @@ static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch", cl::Hidden, cl::init(true)); +static cl::opt<bool> AllowNarrowLatchCondition( + "irce-allow-narrow-latch", cl::Hidden, cl::init(false), + cl::desc("If set to true, IRCE may eliminate wide range checks in loops " + "with narrow latch condition.")); + static const char *ClonedLoopTag = "irce.loop.clone"; #define DEBUG_TYPE "irce" @@ -1044,11 +1049,23 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return Result; } +/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return +/// signed or unsigned extension of \p S to type \p Ty. +static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, + bool Signed) { + return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty); +} + Optional<LoopConstrainer::SubRanges> LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); - if (Range.getType() != Ty) + auto *RTy = cast<IntegerType>(Range.getType()); + + // We only support wide range checks and narrow latches. + if (!AllowNarrowLatchCondition && RTy != Ty) + return None; + if (RTy->getBitWidth() < Ty->getBitWidth()) return None; LoopConstrainer::SubRanges Result; @@ -1056,8 +1073,10 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { // I think we can be more aggressive here and make this nuw / nsw if the // addition that feeds into the icmp for the latch's terminating branch is nuw // / nsw. In any case, a wrapping 2's complement addition is safe. - const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); - const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); + const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart), + RTy, SE, IsSignedPredicate); + const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy, + SE, IsSignedPredicate); bool Increasing = MainLoopStructure.IndVarIncreasing; @@ -1067,7 +1086,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr; - const SCEV *One = SE.getOne(Ty); + const SCEV *One = SE.getOne(RTy); if (Increasing) { Smallest = Start; Greatest = End; @@ -1256,6 +1275,13 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( bool IsSignedPredicate = LS.IsSignedPredicate; IRBuilder<> B(PreheaderJump); + auto *RangeTy = Range.getBegin()->getType(); + auto NoopOrExt = [&](Value *V) { + if (V->getType() == RangeTy) + return V; + return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) + : B.CreateZExt(V, RangeTy, "wide." + V->getName()); + }; // EnterLoopCond - is it okay to start executing this `LS'? Value *EnterLoopCond = nullptr; @@ -1263,9 +1289,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( Increasing ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); - Value *IndVarStart = LS.IndVarStart; - Value *IndVarBase = LS.IndVarBase; - Value *LoopExitAt = LS.LoopExitAt; + Value *IndVarStart = NoopOrExt(LS.IndVarStart); EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); @@ -1273,6 +1297,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); + Value *IndVarBase = NoopOrExt(LS.IndVarBase); Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); Value *CondForBranch = LS.LatchBrExitIdx == 1 @@ -1286,6 +1311,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. + Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); @@ -1394,7 +1420,7 @@ bool LoopConstrainer::run() { SubRanges SR = MaybeSR.getValue(); bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = - cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); + cast<IntegerType>(Range.getBegin()->getType()); SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); Instruction *InsertPt = OriginalPreheader->getTerminator(); @@ -1557,6 +1583,12 @@ Optional<InductiveRangeCheck::Range> InductiveRangeCheck::computeSafeIterationSpace( ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, bool IsLatchSigned) const { + // We can deal when types of latch check and range checks don't match in case + // if latch check is more narrow. + auto *IVType = cast<IntegerType>(IndVar->getType()); + auto *RCType = cast<IntegerType>(getBegin()->getType()); + if (IVType->getBitWidth() > RCType->getBitWidth()) + return None; // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1580,8 +1612,9 @@ InductiveRangeCheck::computeSafeIterationSpace( if (!IndVar->isAffine()) return None; - const SCEV *A = IndVar->getStart(); - const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); + const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned); + const SCEVConstant *B = dyn_cast<SCEVConstant>( + NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned)); if (!B) return None; assert(!B->isZero() && "Recurrence with zero step?"); @@ -1592,7 +1625,7 @@ InductiveRangeCheck::computeSafeIterationSpace( return None; assert(!D->getValue()->isZero() && "Recurrence with zero step?"); - unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); + unsigned BitWidth = RCType->getBitWidth(); const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); // Subtract Y from X so that it does not go through border of the IV |

