summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Analysis/ScalarEvolutionExpander.cpp82
1 files changed, 63 insertions, 19 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
index 5f93768bd2c..83abb8f32a6 100644
--- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -2007,37 +2007,81 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
SCEVUnionPredicate Pred;
const SCEV *ExitCount =
SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred);
+
+ assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
const SCEV *Step = AR->getStepRecurrence(SE);
const SCEV *Start = AR->getStart();
- unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
- unsigned MaxBits = 2 * std::max(DstBits, SrcBits);
-
- auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType());
- IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits);
+ unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
- assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+ // The expression {Start,+,Step} has nusw/nssw if
+ // Step < 0, Start - |Step| * Backedge <= Start
+ // Step >= 0, Start + |Step| * Backedge > Start
+ // and |Step| * Backedge doesn't unsigned overflow.
- const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy);
- const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy);
- const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy)
- : SE.getZeroExtendExpr(Start, MaxTy);
+ IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
+ Builder.SetInsertPoint(Loc);
+ Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
- const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step));
- const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy)
- : SE.getZeroExtendExpr(End, MaxTy);
+ IntegerType *Ty =
+ IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType()));
- const SCEV *LHS = SE.getAddExpr(
- ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep));
+ Value *StepValue = expandCodeFor(Step, Ty, Loc);
+ Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
+ Value *StartValue = expandCodeFor(Start, Ty, Loc);
- // Do all SCEV expansions now.
- Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc);
- Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc);
+ ConstantInt *Zero =
+ ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
Builder.SetInsertPoint(Loc);
+ // Compute |Step|
+ Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero);
+ Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue);
+
+ // Get the backedge taken count and truncate or extended to the AR type.
+ Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty);
+ auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
+ Intrinsic::umul_with_overflow, Ty);
+
+ // Compute |Step| * Backedge
+ CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
+ Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
+ Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
+
+ // Compute:
+ // Start + |Step| * Backedge < Start
+ // Start - |Step| * Backedge > Start
+ Value *Add = Builder.CreateAdd(StartValue, MulV);
+ Value *Sub = Builder.CreateSub(StartValue, MulV);
+
+ Value *EndCompareGT = Builder.CreateICmp(
+ Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
+
+ Value *EndCompareLT = Builder.CreateICmp(
+ Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
+
+ // Select the answer based on the sign of Step.
+ Value *EndCheck =
+ Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
+
+ // If the backedge taken count type is larger than the AR type,
+ // check that we don't drop any bits by truncating it. If we are
+ // droping bits, then we have overflow (unless the step is zero).
+ if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) {
+ auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits);
+ auto *BackedgeCheck =
+ Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal,
+ ConstantInt::get(Loc->getContext(), MaxVal));
+ BackedgeCheck = Builder.CreateAnd(
+ BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero));
+
+ EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck);
+ }
- return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal);
+ EndCheck = Builder.CreateOr(EndCheck, OfMul);
+ return EndCheck;
}
Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
OpenPOWER on IntegriCloud