diff options
author | Silviu Baranga <silviu.baranga@arm.com> | 2016-02-08 17:02:45 +0000 |
---|---|---|
committer | Silviu Baranga <silviu.baranga@arm.com> | 2016-02-08 17:02:45 +0000 |
commit | ea63a7f512dc25451b9da06653d966397fa9b758 (patch) | |
tree | 9bf5c93fa4489e9a59bc762c706de0320affbeec /llvm/lib | |
parent | cbec16036b1e03c5e5ba7ab9fde7a0543286c1c9 (diff) | |
download | bcm5719-llvm-ea63a7f512dc25451b9da06653d966397fa9b758.tar.gz bcm5719-llvm-ea63a7f512dc25451b9da06653d966397fa9b758.zip |
[SCEV][LAA] Re-commit r260085 and r260086, this time with a fix for the memory
sanitizer issue. The PredicatedScalarEvolution's copy constructor
wasn't copying the Generation value, and was leaving it un-initialized.
Original commit message:
[SCEV][LAA] Add no wrap SCEV predicates and use use them to improve strided pointer detection
Summary:
This change adds no wrap SCEV predicates with:
- support for runtime checking
- support for expression rewriting:
(sext ({x,+,y}) -> {sext(x),+,sext(y)}
(zext ({x,+,y}) -> {zext(x),+,sext(y)}
Note that we are sign extending the increment of the SCEV, even for
the zext case. This is needed to cover the fairly common case where y would
be a (small) negative integer. In order to do this, this change adds two new
flags: nusw and nssw that are applicable to AddRecExprs and permit the
transformations above.
We also change isStridedPtr in LAA to be able to make use of
these predicates. With this feature we should now always be able to
work around overflow issues in the dependence analysis.
Reviewers: mzolotukhin, sanjoy, anemet
Subscribers: mzolotukhin, sanjoy, llvm-commits, rengolin, jmolloy, hfinkel
Differential Revision: http://reviews.llvm.org/D15412
llvm-svn: 260112
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/LoopAccessAnalysis.cpp | 61 | ||||
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 198 | ||||
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolutionExpander.cpp | 68 | ||||
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopVersioning.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 2 |
5 files changed, 300 insertions, 30 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index a2ab231a62d..f371feb8142 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -773,7 +773,7 @@ static bool isInBoundsGep(Value *Ptr) { /// \brief Return true if an AddRec pointer \p Ptr is unsigned non-wrapping, /// i.e. monotonically increasing/decreasing. static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, - ScalarEvolution *SE, const Loop *L) { + PredicatedScalarEvolution &PSE, const Loop *L) { // FIXME: This should probably only return true for NUW. if (AR->getNoWrapFlags(SCEV::NoWrapMask)) return true; @@ -809,7 +809,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, // Assume constant for other the operand so that the AddRec can be // easily found. isa<ConstantInt>(OBO->getOperand(1))) { - auto *OpScev = SE->getSCEV(OBO->getOperand(0)); + auto *OpScev = PSE.getSCEV(OBO->getOperand(0)); if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev)) return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW); @@ -820,31 +820,35 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, /// \brief Check whether the access through \p Ptr has a constant stride. int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, - const Loop *Lp, const ValueToValueMap &StridesMap) { + const Loop *Lp, const ValueToValueMap &StridesMap, + bool Assume) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); // Make sure that the pointer does not point to aggregate types. auto *PtrTy = cast<PointerType>(Ty); if (PtrTy->getElementType()->isAggregateType()) { - DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" - << *Ptr << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" << *Ptr + << "\n"); return 0; } const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); + if (Assume && !AR) + AR = dyn_cast<SCEVAddRecExpr>(PSE.getAsAddRec(Ptr)); + if (!AR) { - DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " - << *Ptr << " SCEV: " << *PtrScev << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr + << " SCEV: " << *PtrScev << "\n"); return 0; } // The accesss function must stride over the innermost loop. if (Lp != AR->getLoop()) { DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " << - *Ptr << " SCEV: " << *PtrScev << "\n"); + *Ptr << " SCEV: " << *AR << "\n"); return 0; } @@ -856,12 +860,23 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, // to access the pointer value "0" which is undefined behavior in address // space 0, therefore we can also vectorize this case. bool IsInBoundsGEP = isInBoundsGep(Ptr); - bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp); + bool IsNoWrapAddRec = + PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) || + isNoWrapAddRec(Ptr, AR, PSE, Lp); bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0; if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) { - DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " - << *Ptr << " SCEV: " << *PtrScev << "\n"); - return 0; + if (Assume) { + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + IsNoWrapAddRec = true; + DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + } else { + DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " + << *Ptr << " SCEV: " << *AR << "\n"); + return 0; + } } // Check the step is constant. @@ -871,7 +886,7 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); if (!C) { DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr << - " SCEV: " << *PtrScev << "\n"); + " SCEV: " << *AR << "\n"); return 0; } @@ -895,8 +910,18 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, // know we can't "wrap around the address space". In case of address space // zero we know that this won't happen without triggering undefined behavior. if (!IsNoWrapAddRec && (IsInBoundsGEP || IsInAddressSpaceZero) && - Stride != 1 && Stride != -1) - return 0; + Stride != 1 && Stride != -1) { + if (Assume) { + // We can avoid this case by adding a run-time check. + DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either " + << "inbouds or in address space 0 may wrap:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + } else + return 0; + } return Stride; } @@ -1123,8 +1148,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr); const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr); - int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides); + int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides, true); + int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides, true); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1824,7 +1849,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), + : PSE(*SE, *L), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 32bf11ac5c5..c618791272e 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9627,17 +9627,40 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, return Eq; } +const SCEVPredicate *ScalarEvolution::getWrapPredicate( + const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Wrap); + ID.AddPointer(AR); + ID.AddInteger(AddedFlags); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + auto *OF = new (SCEVAllocator) + SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); + UniquePreds.InsertNode(OF, IP); + return OF; +} + namespace { + class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - SCEVUnionPredicate &A) { - SCEVPredicateRewriter Rewriter(SE, A); + // Rewrites Scev in the context of a loop L and the predicate A. + // If Assume is true, rewrite is free to add further predicates to A + // such that the result will be an AddRecExpr. + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE, SCEVUnionPredicate &A, + bool Assume) { + SCEVPredicateRewriter Rewriter(L, SE, A, Assume); return Rewriter.visit(Scev); } - SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) - : SCEVRewriteVisitor(SE), P(P) {} + SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, + SCEVUnionPredicate &P, bool Assume) + : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { auto ExprPreds = P.getPredicatesForExpr(Expr); @@ -9649,14 +9672,67 @@ public: return Expr; } + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast<const SCEVAddRecExpr>(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nuw + // flag. Add the nusw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) + return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast<const SCEVAddRecExpr>(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nsw + // flag. Add the nssw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) + return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + private: + bool addOverflowAssumption(const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + auto *A = SE.getWrapPredicate(AR, AddedFlags); + if (!Assume) { + // Check if we've already made this assumption. + if (P.implies(A)) + return true; + return false; + } + P.add(A); + return true; + } + SCEVUnionPredicate &P; + const Loop *L; + bool Assume; }; } // end anonymous namespace const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, + const Loop *L, SCEVUnionPredicate &Preds) { - return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); + return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, false); +} + +const SCEV *ScalarEvolution::convertSCEVToAddRecWithPredicates( + const SCEV *Scev, const Loop *L, SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(Scev, L, *this, Preds, true); } /// SCEV predicates @@ -9686,6 +9762,59 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; } +SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, + const SCEVAddRecExpr *AR, + IncrementWrapFlags Flags) + : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} + +const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } + +bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast<SCEVWrapPredicate>(N); + + return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; +} + +bool SCEVWrapPredicate::isAlwaysTrue() const { + SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); + IncrementWrapFlags IFlags = Flags; + + if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) + IFlags = clearFlags(IFlags, IncrementNSSW); + + return IFlags == IncrementAnyWrap; +} + +void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << *getExpr() << " Added Flags: "; + if (SCEVWrapPredicate::IncrementNUSW & getFlags()) + OS << "<nusw>"; + if (SCEVWrapPredicate::IncrementNSSW & getFlags()) + OS << "<nssw>"; + OS << "\n"; +} + +SCEVWrapPredicate::IncrementWrapFlags +SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, + ScalarEvolution &SE) { + IncrementWrapFlags ImpliedFlags = IncrementAnyWrap; + SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); + + // We can safely transfer the NSW flag as NSSW. + if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) + ImpliedFlags = IncrementNSSW; + + if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { + // If the increment is positive, the SCEV NUW flag will also imply the + // WrapPredicate NUSW flag. + if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) + if (Step->getValue()->getValue().isNonNegative()) + ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); + } + + return ImpliedFlags; +} + /// Union predicates don't get cached so create a dummy set ID for it. SCEVUnionPredicate::SCEVUnionPredicate() : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} @@ -9742,8 +9871,9 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) { Preds.push_back(N); } -PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE) - : SE(SE), Generation(0) {} +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, + Loop &L) + : SE(SE), L(L), Generation(0) {} const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { const SCEV *Expr = SE.getSCEV(V); @@ -9758,7 +9888,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { if (Entry.second) Expr = Entry.second; - const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds); + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); Entry = {Generation, NewSCEV}; return NewSCEV; @@ -9780,7 +9910,55 @@ void PredicatedScalarEvolution::updateGeneration() { if (++Generation == 0) { for (auto &II : RewriteMap) { const SCEV *Rewritten = II.second.second; - II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)}; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; } } } + +void PredicatedScalarEvolution::setNoOverflow( + Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast<SCEVAddRecExpr>(Expr); + + auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); + + // Clear the statically implied flags. + Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); + addPredicate(*SE.getWrapPredicate(AR, Flags)); + + auto II = FlagsMap.insert({V, Flags}); + if (!II.second) + II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); +} + +bool PredicatedScalarEvolution::hasNoOverflow( + Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast<SCEVAddRecExpr>(Expr); + + Flags = SCEVWrapPredicate::clearFlags( + Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); + + auto II = FlagsMap.find(V); + + if (II != FlagsMap.end()) + Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); + + return Flags == SCEVWrapPredicate::IncrementAnyWrap; +} + +const SCEV *PredicatedScalarEvolution::getAsAddRec(Value *V) { + const SCEV *Expr = this->getSCEV(V); + const SCEV *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); + updateGeneration(); + RewriteMap[SE.getSCEV(V)] = {Generation, New}; + return New; +} + +PredicatedScalarEvolution:: +PredicatedScalarEvolution(const PredicatedScalarEvolution &Init) : + RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), + Generation(Init.Generation) { + for (auto I = Init.FlagsMap.begin(), E = Init.FlagsMap.end(); I != E; ++I) + FlagsMap.insert(*I); +} diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp index 86cfe2d00ab..49cdcf75818 100644 --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1971,6 +1971,10 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP); case SCEVPredicate::P_Equal: return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP); + case SCEVPredicate::P_Wrap: { + auto *AddRecPred = cast<SCEVWrapPredicate>(Pred); + return expandWrapPredicate(AddRecPred, IP); + } } llvm_unreachable("Unknown SCEV predicate type"); } @@ -1985,6 +1989,70 @@ Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, return I; } +Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, + Instruction *Loc, bool Signed) { + assert(AR->isAffine() && "Cannot generate RT check for " + "non-affine expression"); + + const SCEV *ExitCount = SE.getBackedgeTakenCount(AR->getLoop()); + const SCEV *Step = AR->getStepRecurrence(SE); + const SCEV *Start = AR->getStart(); + + unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); + unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType()); + unsigned MaxBits = 2 * std::max(DstBits, SrcBits); + + auto *TripCount = SE.getTruncateOrZeroExtend(ExitCount, AR->getType()); + IntegerType *MaxTy = IntegerType::get(Loc->getContext(), MaxBits); + + assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + + const auto *ExtendedTripCount = SE.getZeroExtendExpr(ExitCount, MaxTy); + const auto *ExtendedStep = SE.getSignExtendExpr(Step, MaxTy); + const auto *ExtendedStart = Signed ? SE.getSignExtendExpr(Start, MaxTy) + : SE.getZeroExtendExpr(Start, MaxTy); + + const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(TripCount, Step)); + const SCEV *RHS = Signed ? SE.getSignExtendExpr(End, MaxTy) + : SE.getZeroExtendExpr(End, MaxTy); + + const SCEV *LHS = SE.getAddExpr( + ExtendedStart, SE.getMulExpr(ExtendedTripCount, ExtendedStep)); + + // Do all SCEV expansions now. + Value *LHSVal = expandCodeFor(LHS, MaxTy, Loc); + Value *RHSVal = expandCodeFor(RHS, MaxTy, Loc); + + Builder.SetInsertPoint(Loc); + + return Builder.CreateICmp(ICmpInst::ICMP_NE, RHSVal, LHSVal); +} + +Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, + Instruction *IP) { + const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr()); + Value *NSSWCheck = nullptr, *NUSWCheck = nullptr; + + // Add a check for NUSW + if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW) + NUSWCheck = generateOverflowCheck(A, IP, false); + + // Add a check for NSSW + if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW) + NSSWCheck = generateOverflowCheck(A, IP, true); + + if (NUSWCheck && NSSWCheck) + return Builder.CreateOr(NUSWCheck, NSSWCheck); + + if (NUSWCheck) + return NUSWCheck; + + if (NSSWCheck) + return NSSWCheck; + + return ConstantInt::getFalse(IP->getContext()); +} + Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, Instruction *IP) { auto *BoolType = IntegerType::get(IP->getContext(), 1); diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 1aeffb79b19..9303dee99db 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -56,7 +56,6 @@ void LoopVersioning::versionLoop( BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); std::tie(FirstCheckInst, MemRuntimeCheck) = LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks); - assert(MemRuntimeCheck && "called even though needsAnyChecking = false"); const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate(); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index b885dbe857a..c5a10ba7b60 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1751,7 +1751,7 @@ struct LoopVectorize : public FunctionPass { } } - PredicatedScalarEvolution PSE(*SE); + PredicatedScalarEvolution PSE(*SE, *L); // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; |