diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolutionExpander.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolutionExpander.cpp | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
index 86cfe2d00ab..49cdcf75818 100644
--- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp
@@ -1971,6 +1971,10 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
     return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
   case SCEVPredicate::P_Equal:
     return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  case SCEVPredicate::P_Wrap: {
+    auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
+    return expandWrapPredicate(AddRecPred, IP);
+  }
   }
   llvm_unreachable("Unknown SCEV predicate type");
 }
@@ -1985,6 +1989,88 @@ Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
   return I;
 }
 
+/// Emit a runtime overflow check for the affine AddRec \p AR before \p Loc.
+///
+/// The value of \p AR after the loop's backedge-taken count is computed twice:
+/// once in AR's own type, and once in an integer type with twice as many bits
+/// as the wider of AR's type and the backedge-taken count's type. \p Signed
+/// selects sign- or zero-extension for the start and for the narrow result;
+/// the step is always sign-extended and the count always zero-extended.
+///
+/// \returns an i1 value that is true when the two computations disagree.
+Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
+                                           Instruction *Loc, bool Signed) {
+  assert(AR->isAffine() && "Cannot generate RT check for "
+                           "non-affine expression");
+
+  const SCEV *ExitCount = SE.getBackedgeTakenCount(AR->getLoop());
+  // Validate the count before its type is queried or it is extended below.
+  assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *Start = AR->getStart();
+
+  unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
+  unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
+  unsigned MaxBits = 2 * std::max(DstBits, SrcBits);
+
+  auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType());
+  IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits);
+
+  const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy);
+  const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy);
+  const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy)
+                                     : SE.getZeroExtendExpr(Start, MaxTy);
+
+  // End value computed in AR's type, then widened for the comparison.
+  const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step));
+  const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy)
+                           : SE.getZeroExtendExpr(End, MaxTy);
+
+  // The same end value computed entirely in the wide type.
+  const SCEV *LHS = SE.getAddExpr(
+      ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep));
+
+  // Do all SCEV expansions now.
+  Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc);
+  Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc);
+
+  Builder.SetInsertPoint(Loc);
+
+  return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal);
+}
+
+/// Expand the wrap predicate \p Pred into a runtime check inserted at \p IP.
+///
+/// An overflow check is generated for each of IncrementNUSW and IncrementNSSW
+/// that is set in the predicate's flags; when both are set they are OR'ed.
+///
+/// \returns the check value, or constant false when neither flag is set.
+Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred,
+                                         Instruction *IP) {
+  const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr());
+  Value *NSSWCheck = nullptr, *NUSWCheck = nullptr;
+
+  // Add a check for NUSW
+  if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW)
+    NUSWCheck = generateOverflowCheck(A, IP, false);
+
+  // Add a check for NSSW
+  if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW)
+    NSSWCheck = generateOverflowCheck(A, IP, true);
+
+  if (NUSWCheck && NSSWCheck)
+    return Builder.CreateOr(NUSWCheck, NSSWCheck);
+
+  if (NUSWCheck)
+    return NUSWCheck;
+
+  if (NSSWCheck)
+    return NSSWCheck;
+
+  return ConstantInt::getFalse(IP->getContext());
+}
+
 Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
                                           Instruction *IP) {
   auto *BoolType = IntegerType::get(IP->getContext(), 1);