diff options
Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolutionExpander.cpp | 82 |
1 files changed, 63 insertions, 19 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp index 5f93768bd2c..83abb8f32a6 100644 --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -2007,37 +2007,81 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, SCEVUnionPredicate Pred; const SCEV *ExitCount = SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred); + + assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + const SCEV *Step = AR->getStepRecurrence(SE); const SCEV *Start = AR->getStart(); - unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType()); - unsigned MaxBits = 2 * std::max(DstBits, SrcBits); - - auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType()); - IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits); + unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); - assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + // The expression {Start,+,Step} has nusw/nssw if + // Step < 0, Start - |Step| * Backedge <= Start + // Step >= 0, Start + |Step| * Backedge > Start + // and |Step| * Backedge doesn't unsigned overflow. - const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy); - const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy); - const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy) - : SE.getZeroExtendExpr(Start, MaxTy); + IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); + Builder.SetInsertPoint(Loc); + Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc); - const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step)); - const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy) - : SE.getZeroExtendExpr(End, MaxTy); + IntegerType *Ty = + IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType())); - const SCEV *LHS = SE.getAddExpr( - ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep)); + Value *StepValue = expandCodeFor(Step, Ty, Loc); + Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc); + Value *StartValue = expandCodeFor(Start, Ty, Loc); - // Do all SCEV expansions now. - Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc); - Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc); + ConstantInt *Zero = + ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits)); Builder.SetInsertPoint(Loc); + // Compute |Step| + Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero); + Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue); + + // Get the backedge taken count and truncate or extended to the AR type. + Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); + auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), + Intrinsic::umul_with_overflow, Ty); + + // Compute |Step| * Backedge + CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); + Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); + Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); + + // Compute: + // Start + |Step| * Backedge < Start + // Start - |Step| * Backedge > Start + Value *Add = Builder.CreateAdd(StartValue, MulV); + Value *Sub = Builder.CreateSub(StartValue, MulV); + + Value *EndCompareGT = Builder.CreateICmp( + Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue); + + Value *EndCompareLT = Builder.CreateICmp( + Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue); + + // Select the answer based on the sign of Step. + Value *EndCheck = + Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT); + + // If the backedge taken count type is larger than the AR type, + // check that we don't drop any bits by truncating it. If we are + // droping bits, then we have overflow (unless the step is zero). + if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { + auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); + auto *BackedgeCheck = + Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, + ConstantInt::get(Loc->getContext(), MaxVal)); + BackedgeCheck = Builder.CreateAnd( + BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero)); + + EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck); + } - return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal); + EndCheck = Builder.CreateOr(EndCheck, OfMul); + return EndCheck; } Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, |