summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
diff options
context:
space:
mode:
authorSilviu Baranga <silviu.baranga@arm.com>2016-04-25 09:27:16 +0000
committerSilviu Baranga <silviu.baranga@arm.com>2016-04-25 09:27:16 +0000
commit795c629ec93cfa4da558df3231c1309fbe6883be (patch)
treeb41c451865a0d8d0698edc33c905f0d09c9573bf /llvm/lib/Analysis/ScalarEvolutionExpander.cpp
parenta44d44cb2ea8264c905d65e8bb1c94f3abbfac3d (diff)
downloadbcm5719-llvm-795c629ec93cfa4da558df3231c1309fbe6883be.tar.gz
bcm5719-llvm-795c629ec93cfa4da558df3231c1309fbe6883be.zip
[SCEV] Improve the run-time checking of the NoWrap predicate
Summary: This implements a new method of run-time checking the NoWrap SCEV predicates, which should be easier to optimize and nicer for targets that don't correctly handle multiplication/addition of large integer types (like i128). If the AddRec is {a,+,b} and the backedge taken count is c, the idea is to check that |b| * c doesn't have unsigned overflow, and depending on the sign of b, that: a + |b| * c >= a (b >= 0) or a - |b| * c <= a (b <= 0) where the comparisons above are signed or unsigned, depending on the flag that we're checking. The advantage of doing this is that we avoid extending to a larger type and we avoid the multiplication of large types (multiplying i128 can be expensive). Reviewers: sanjoy Subscribers: llvm-commits, mzolotukhin Differential Revision: http://reviews.llvm.org/D19266 llvm-svn: 267389
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolutionExpander.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolutionExpander.cpp82
1 files changed, 63 insertions, 19 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
index 5f93768bd2c..83abb8f32a6 100644
--- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -2007,37 +2007,81 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
SCEVUnionPredicate Pred;
const SCEV *ExitCount =
SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred);
+
+ assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
const SCEV *Step = AR->getStepRecurrence(SE);
const SCEV *Start = AR->getStart();
- unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
- unsigned MaxBits = 2 * std::max(DstBits, SrcBits);
-
- auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType());
- IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits);
+ unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
- assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+ // The expression {Start,+,Step} has nusw/nssw if
+ // Step < 0, Start - |Step| * Backedge <= Start
+ // Step >= 0, Start + |Step| * Backedge > Start
+ // and |Step| * Backedge doesn't unsigned overflow.
- const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy);
- const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy);
- const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy)
- : SE.getZeroExtendExpr(Start, MaxTy);
+ IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
+ Builder.SetInsertPoint(Loc);
+ Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
- const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step));
- const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy)
- : SE.getZeroExtendExpr(End, MaxTy);
+ IntegerType *Ty =
+ IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType()));
- const SCEV *LHS = SE.getAddExpr(
- ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep));
+ Value *StepValue = expandCodeFor(Step, Ty, Loc);
+ Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
+ Value *StartValue = expandCodeFor(Start, Ty, Loc);
- // Do all SCEV expansions now.
- Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc);
- Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc);
+ ConstantInt *Zero =
+ ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
Builder.SetInsertPoint(Loc);
+ // Compute |Step|
+ Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero);
+ Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue);
+
+ // Get the backedge taken count and truncate or extended to the AR type.
+ Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty);
+ auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
+ Intrinsic::umul_with_overflow, Ty);
+
+ // Compute |Step| * Backedge
+ CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
+ Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
+ Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
+
+ // Compute:
+ // Start + |Step| * Backedge < Start
+ // Start - |Step| * Backedge > Start
+ Value *Add = Builder.CreateAdd(StartValue, MulV);
+ Value *Sub = Builder.CreateSub(StartValue, MulV);
+
+ Value *EndCompareGT = Builder.CreateICmp(
+ Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
+
+ Value *EndCompareLT = Builder.CreateICmp(
+ Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
+
+ // Select the answer based on the sign of Step.
+ Value *EndCheck =
+ Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
+
+ // If the backedge taken count type is larger than the AR type,
+ // check that we don't drop any bits by truncating it. If we are
+ // droping bits, then we have overflow (unless the step is zero).
+ if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) {
+ auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits);
+ auto *BackedgeCheck =
+ Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal,
+ ConstantInt::get(Loc->getContext(), MaxVal));
+ BackedgeCheck = Builder.CreateAnd(
+ BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero));
+
+ EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck);
+ }
- return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal);
+ EndCheck = Builder.CreateOr(EndCheck, OfMul);
+ return EndCheck;
}
Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
OpenPOWER on IntegriCloud