summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolutionExpander.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolutionExpander.cpp68
1 files changed, 68 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
index 86cfe2d00ab..49cdcf75818 100644
--- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -1971,6 +1971,10 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
case SCEVPredicate::P_Equal:
return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+ case SCEVPredicate::P_Wrap: {
+ auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
+ return expandWrapPredicate(AddRecPred, IP);
+ }
}
llvm_unreachable("Unknown SCEV predicate type");
}
@@ -1985,6 +1989,70 @@ Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
return I;
}
+Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
+ Instruction *Loc, bool Signed) {
+ assert(AR->isAffine() && "Cannot generate RT check for "
+ "non-affine expression");
+
+ const SCEV *ExitCount = SE.getBackedgeTakenCount(AR->getLoop());
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *Start = AR->getStart();
+
+ unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
+ unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
+ unsigned MaxBits = 2 * std::max(DstBits, SrcBits);
+
+ auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType());
+ IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits);
+
+ assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
+ const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy);
+ const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy);
+ const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy)
+ : SE.getZeroExtendExpr(Start, MaxTy);
+
+ const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step));
+ const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy)
+ : SE.getZeroExtendExpr(End, MaxTy);
+
+ const SCEV *LHS = SE.getAddExpr(
+ ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep));
+
+ // Do all SCEV expansions now.
+ Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc);
+ Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc);
+
+ Builder.SetInsertPoint(Loc);
+
+ return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal);
+}
+
+Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
+ Instruction *IP) {
+ const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr());
+ Value *NSSWCheck = nullptr, *NUSWCheck = nullptr;
+
+ // Add a check for NUSW
+ if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW)
+ NUSWCheck = generateOverflowCheck(A, IP, false);
+
+ // Add a check for NSSW
+ if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW)
+ NSSWCheck = generateOverflowCheck(A, IP, true);
+
+ if (NUSWCheck && NSSWCheck)
+ return Builder.CreateOr(NUSWCheck, NSSWCheck);
+
+ if (NUSWCheck)
+ return NUSWCheck;
+
+ if (NSSWCheck)
+ return NSSWCheck;
+
+ return ConstantInt::getFalse(IP->getContext());
+}
+
Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
Instruction *IP) {
auto *BoolType = IntegerType::get(IP->getContext(), 1);
OpenPOWER on IntegriCloud