diff options
author | Silviu Baranga <silviu.baranga@arm.com> | 2016-04-25 09:27:16 +0000 |
---|---|---|
committer | Silviu Baranga <silviu.baranga@arm.com> | 2016-04-25 09:27:16 +0000 |
commit | 795c629ec93cfa4da558df3231c1309fbe6883be (patch) | |
tree | b41c451865a0d8d0698edc33c905f0d09c9573bf /llvm/lib/Analysis/ScalarEvolutionExpander.cpp | |
parent | a44d44cb2ea8264c905d65e8bb1c94f3abbfac3d (diff) | |
download | bcm5719-llvm-795c629ec93cfa4da558df3231c1309fbe6883be.tar.gz bcm5719-llvm-795c629ec93cfa4da558df3231c1309fbe6883be.zip |
[SCEV] Improve the run-time checking of the NoWrap predicate
Summary:
This implements a new method of run-time checking the NoWrap
SCEV predicates, which should be easier to optimize and nicer
for targets that don't correctly handle multiplication/addition
of large integer types (like i128).
If the AddRec is {a,+,b} and the backedge taken count is c,
the idea is to check that |b| * c doesn't have unsigned overflow,
and depending on the sign of b, that:
a + |b| * c >= a (b >= 0) or
a - |b| * c <= a (b < 0)
where the comparisons above are signed or unsigned, depending on
the flag that we're checking.
The advantage of doing this is that we avoid extending to a larger
type and we avoid the multiplication of large types (multiplying
i128 can be expensive).
Reviewers: sanjoy
Subscribers: llvm-commits, mzolotukhin
Differential Revision: http://reviews.llvm.org/D19266
llvm-svn: 267389
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolutionExpander.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolutionExpander.cpp | 82 |
1 files changed, 63 insertions, 19 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp index 5f93768bd2c..83abb8f32a6 100644 --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -2007,37 +2007,81 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, SCEVUnionPredicate Pred; const SCEV *ExitCount = SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred); + + assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + const SCEV *Step = AR->getStepRecurrence(SE); const SCEV *Start = AR->getStart(); - unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType()); - unsigned MaxBits = 2 * std::max(DstBits, SrcBits); - - auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType()); - IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits); + unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); - assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + // The expression {Start,+,Step} has nusw/nssw if + // Step < 0, Start - |Step| * Backedge <= Start + // Step >= 0, Start + |Step| * Backedge > Start + // and |Step| * Backedge doesn't unsigned overflow. - const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy); - const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy); - const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy) - : SE.getZeroExtendExpr(Start, MaxTy); + IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); + Builder.SetInsertPoint(Loc); + Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc); - const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step)); - const SCEV *RHS = Signed ? 
SE.getSignExtendExpr(End, MaxTy) - : SE.getZeroExtendExpr(End, MaxTy); + IntegerType *Ty = + IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType())); - const SCEV *LHS = SE.getAddExpr( - ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep)); + Value *StepValue = expandCodeFor(Step, Ty, Loc); + Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc); + Value *StartValue = expandCodeFor(Start, Ty, Loc); - // Do all SCEV expansions now. - Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc); - Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc); + ConstantInt *Zero = + ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits)); Builder.SetInsertPoint(Loc); + // Compute |Step| + Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero); + Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue); + + // Get the backedge taken count and truncate or extended to the AR type. + Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); + auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), + Intrinsic::umul_with_overflow, Ty); + + // Compute |Step| * Backedge + CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); + Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); + Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); + + // Compute: + // Start + |Step| * Backedge < Start + // Start - |Step| * Backedge > Start + Value *Add = Builder.CreateAdd(StartValue, MulV); + Value *Sub = Builder.CreateSub(StartValue, MulV); + + Value *EndCompareGT = Builder.CreateICmp( + Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue); + + Value *EndCompareLT = Builder.CreateICmp( + Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue); + + // Select the answer based on the sign of Step. 
+ Value *EndCheck = + Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT); + + // If the backedge taken count type is larger than the AR type, + // check that we don't drop any bits by truncating it. If we are + // droping bits, then we have overflow (unless the step is zero). + if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { + auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); + auto *BackedgeCheck = + Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, + ConstantInt::get(Loc->getContext(), MaxVal)); + BackedgeCheck = Builder.CreateAnd( + BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero)); + + EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck); + } - return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal); + EndCheck = Builder.CreateOr(EndCheck, OfMul); + return EndCheck; } Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, |