diff options
author | Silviu Baranga <silviu.baranga@arm.com> | 2015-12-09 16:06:28 +0000 |
---|---|---|
committer | Silviu Baranga <silviu.baranga@arm.com> | 2015-12-09 16:06:28 +0000 |
commit | 9cd9a7e310933b147e22220cdcea4bd6263be5fb (patch) | |
tree | 75f376610145b52f5735abb4515412adedcb3d9a /llvm/lib | |
parent | e69df38282957733ad88ff44089892013318f7b4 (diff) | |
download | bcm5719-llvm-9cd9a7e310933b147e22220cdcea4bd6263be5fb.tar.gz bcm5719-llvm-9cd9a7e310933b147e22220cdcea4bd6263be5fb.zip |
Re-commit r255115, with the PredicatedScalarEvolution class moved to
ScalarEvolution.h, in order to avoid cyclic dependencies between the Transform
and Analysis modules:
[LV][LAA] Add a layer over SCEV to apply run-time checked knowledge on SCEV expressions
Summary:
This change creates a layer over ScalarEvolution for LAA and LV, and centralizes the
usage of SCEV predicates. The SCEVPredicatedLayer takes the statically deduced knowledge
by ScalarEvolution and applies the knowledge from the SCEV predicates. The end goal is
that both LAA and LV should use this interface everywhere.
This also solves a problem involving the result of SCEV expression rewritting when
the predicate changes. Suppose we have the expression (sext {a,+,b}) and two predicates
P1: {a,+,b} has nsw
P2: b = 1.
Applying P1 and then P2 gives us {a,+,1}, while applying P2 and the P1 gives us
sext({a,+,1}) (the AddRec expression was changed by P2 so P1 no longer applies).
The SCEVPredicatedLayer maintains the order of transformations by feeding back
the results of previous transformations into new transformations, and therefore
avoiding this issue.
The SCEVPredicatedLayer maintains a cache to remember the results of previous
SCEV rewritting results. This also has the benefit of reducing the overall number
of expression rewrites.
Reviewers: mzolotukhin, anemet
Subscribers: jmolloy, sanjoy, llvm-commits
Differential Revision: http://reviews.llvm.org/D14296
llvm-svn: 255122
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/LoopAccessAnalysis.cpp | 89 | ||||
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 43 | ||||
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopDistribute.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp | 7 | ||||
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopVersioning.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 165 |
6 files changed, 176 insertions, 136 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index b2670bf48dd..ce6a5ab5656 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -87,11 +87,10 @@ Value *llvm::stripIntegerCast(Value *V) { return V; } -const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, +const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE, const ValueToValueMap &PtrToStride, - SCEVUnionPredicate &Preds, Value *Ptr, Value *OrigPtr) { - const SCEV *OrigSCEV = SE->getSCEV(Ptr); + const SCEV *OrigSCEV = PSE.getSCEV(Ptr); // If there is an entry in the map return the SCEV of the pointer with the // symbolic stride replaced by one. @@ -108,16 +107,17 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, ValueToValueMap RewriteMap; RewriteMap[StrideVal] = One; + ScalarEvolution *SE = PSE.getSE(); const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal)); const auto *CT = static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType())); - Preds.add(SE->getEqualPredicate(U, CT)); + PSE.addPredicate(*SE->getEqualPredicate(U, CT)); + auto *Expr = PSE.getSCEV(Ptr); - const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds); - DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne + DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *Expr << "\n"); - return ByOne; + return Expr; } // Otherwise, just return the SCEV of the original pointer. @@ -127,11 +127,12 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE, void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId, const ValueToValueMap &Strides, - SCEVUnionPredicate &Preds) { + PredicatedScalarEvolution &PSE) { // Get the stride replaced scev. - const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); + const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); assert(AR && "Invalid addrec expression"); + ScalarEvolution *SE = PSE.getSE(); const SCEV *Ex = SE->getBackedgeTakenCount(Lp); const SCEV *ScStart = AR->getStart(); @@ -423,9 +424,10 @@ public: typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet; AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI, - MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds) + MemoryDepChecker::DepCandidates &DA, + PredicatedScalarEvolution &PSE) : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false), - Preds(Preds) {} + PSE(PSE) {} /// \brief Register a load and whether it is only read from. void addLoad(MemoryLocation &Loc, bool IsReadOnly) { @@ -512,16 +514,16 @@ private: bool IsRTCheckAnalysisNeeded; /// The SCEV predicate containing all the SCEV-related assumptions. - SCEVUnionPredicate &Preds; + PredicatedScalarEvolution &PSE; }; } // end anonymous namespace /// \brief Check whether a pointer can participate in a runtime bounds check. -static bool hasComputableBounds(ScalarEvolution *SE, +static bool hasComputableBounds(PredicatedScalarEvolution &PSE, const ValueToValueMap &Strides, Value *Ptr, - Loop *L, SCEVUnionPredicate &Preds) { - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); + Loop *L) { + const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); if (!AR) return false; @@ -564,11 +566,11 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, else ++NumReadPtrChecks; - if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) && + if (hasComputableBounds(PSE, StridesMap, Ptr, TheLoop) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. (!ShouldCheckStride || - isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) { + isStridedPtr(PSE, Ptr, TheLoop, StridesMap) == 1)) { // The id of the dependence set. unsigned DepId; @@ -582,7 +584,7 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, // Each access has its own dependence set. DepId = RunningDepId++; - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds); + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE); DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); } else { @@ -817,9 +819,8 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, } /// \brief Check whether the access through \p Ptr has a constant stride. -int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap, - SCEVUnionPredicate &Preds) { +int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, + const Loop *Lp, const ValueToValueMap &StridesMap) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -831,7 +832,7 @@ int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, return 0; } - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr); + const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); if (!AR) { @@ -854,16 +855,16 @@ int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp, // to access the pointer value "0" which is undefined behavior in address // space 0, therefore we can also vectorize this case. bool IsInBoundsGEP = isInBoundsGep(Ptr); - bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, SE, Lp); + bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp); bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0; if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) { DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " - << *Ptr << " SCEV: " << *PtrScev << "\n"); + << *Ptr << " SCEV: " << *PtrScev << "\n"); return 0; } // Check the step is constant. - const SCEV *Step = AR->getStepRecurrence(*SE); + const SCEV *Step = AR->getStepRecurrence(*PSE.getSE()); // Calculate the pointer stride and check if it is constant. const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); @@ -1046,11 +1047,11 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, BPtr->getType()->getPointerAddressSpace()) return Dependence::Unknown; - const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr); - const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr); + const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr); + const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr); - int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds); - int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds); + int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides); + int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides); const SCEV *Src = AScev; const SCEV *Sink = BScev; @@ -1067,12 +1068,12 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, std::swap(StrideAPtr, StrideBPtr); } - const SCEV *Dist = SE->getMinusSCEV(Sink, Src); + const SCEV *Dist = PSE.getSE()->getMinusSCEV(Sink, Src); DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink - << "(Induction step: " << StrideAPtr << ")\n"); + << "(Induction step: " << StrideAPtr << ")\n"); DEBUG(dbgs() << "LAA: Distance for " << *InstMap[AIdx] << " to " - << *InstMap[BIdx] << ": " << *Dist << "\n"); + << *InstMap[BIdx] << ": " << *Dist << "\n"); // Need accesses with constant stride. We don't want to vectorize // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in @@ -1343,10 +1344,10 @@ bool LoopAccessInfo::canAnalyzeLoop() { } // ScalarEvolution needs to be able to find the exit count. - const SCEV *ExitCount = SE->getBackedgeTakenCount(TheLoop); - if (ExitCount == SE->getCouldNotCompute()) { - emitAnalysis(LoopAccessReport() << - "could not determine number of loop iterations"); + const SCEV *ExitCount = PSE.getSE()->getBackedgeTakenCount(TheLoop); + if (ExitCount == PSE.getSE()->getCouldNotCompute()) { + emitAnalysis(LoopAccessReport() + << "could not determine number of loop iterations"); DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n"); return false; } @@ -1447,7 +1448,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses, Preds); + AA, LI, DependentAccesses, PSE); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1498,8 +1499,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { // read a few words, modify, and write a few words, and some of the // words may be written to the same address. bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || - !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) { + if (Seen.insert(Ptr).second || !isStridedPtr(PSE, Ptr, TheLoop, Strides)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1529,7 +1529,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. bool CanDoRTIfNeeded = - Accesses.canCheckPtrAtRT(PtrRtChecking, SE, TheLoop, Strides); + Accesses.canCheckPtrAtRT(PtrRtChecking, PSE.getSE(), TheLoop, Strides); if (!CanDoRTIfNeeded) { emitAnalysis(LoopAccessReport() << "cannot identify array bounds"); DEBUG(dbgs() << "LAA: We can't vectorize because we can't find " @@ -1556,6 +1556,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { PtrRtChecking.reset(); PtrRtChecking.Need = true; + auto *SE = PSE.getSE(); CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(PtrRtChecking, SE, TheLoop, Strides, true); @@ -1598,7 +1599,7 @@ void LoopAccessInfo::emitAnalysis(LoopAccessReport &Message) { } bool LoopAccessInfo::isUniform(Value *V) const { - return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop)); + return (PSE.getSE()->isLoopInvariant(PSE.getSE()->getSCEV(V), TheLoop)); } // FIXME: this function is currently a duplicate of the one in @@ -1679,7 +1680,7 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks( Instruction *Loc, const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks) const { - + auto *SE = PSE.getSE(); SCEVExpander Exp(*SE, DL, "induction"); auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, PtrRtChecking); @@ -1749,7 +1750,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL), + : PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { @@ -1786,7 +1787,7 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { << "found in loop.\n"; OS.indent(Depth) << "SCEV assumptions:\n"; - Preds.print(OS, Depth); + PSE.getUnionPredicate().print(OS, Depth); } const LoopAccessInfo & diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index f57997e146e..1c2fb3d1ed0 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9707,3 +9707,46 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) { SCEVToPreds[Key].push_back(N); Preds.push_back(N); } + +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE) + : SE(SE), Generation(0) {} + +const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { + const SCEV *Expr = SE.getSCEV(V); + RewriteEntry &Entry = RewriteMap[Expr]; + + // If we already have an entry and the version matches, return it. + if (Entry.second && Generation == Entry.first) + return Entry.second; + + // We found an entry but it's stale. Rewrite the stale entry + // acording to the current predicate. + if (Entry.second) + Expr = Entry.second; + + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds); + Entry = {Generation, NewSCEV}; + + return NewSCEV; +} + +void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { + if (Preds.implies(&Pred)) + return; + Preds.add(&Pred); + updateGeneration(); +} + +const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { + return Preds; +} + +void PredicatedScalarEvolution::updateGeneration() { + // If the generation number wrapped recompute everything. + if (++Generation == 0) { + for (auto &II : RewriteMap) { + const SCEV *Rewritten = II.second.second; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)}; + } + } +} diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 67ebd2532b1..fce063ab40a 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -761,7 +761,7 @@ private: } // Don't distribute the loop if we need too many SCEV run-time checks. - const SCEVUnionPredicate &Pred = LAI.Preds; + const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate(); if (Pred.getComplexity() > DistributeSCEVCheckThreshold) { DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); return false; @@ -790,7 +790,7 @@ private: DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); LoopVersioning LVer(LAI, L, LI, DT, SE, false); LVer.setAliasChecks(std::move(Checks)); - LVer.setSCEVChecks(LAI.Preds); + LVer.setSCEVChecks(LAI.PSE.getUnionPredicate()); LVer.versionLoop(DefsUsedOutside); } diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 7c7bf64ba79..09d022b3013 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -459,17 +459,18 @@ public: return false; } - if (LAI.Preds.getComplexity() > LoadElimSCEVCheckThreshold) { + if (LAI.PSE.getUnionPredicate().getComplexity() > + LoadElimSCEVCheckThreshold) { DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); return false; } // Point of no-return, start the transformation. First, version the loop if // necessary. - if (!Checks.empty() || !LAI.Preds.isAlwaysTrue()) { + if (!Checks.empty() || !LAI.PSE.getUnionPredicate().isAlwaysTrue()) { LoopVersioning LV(LAI, L, LI, DT, SE, false); LV.setAliasChecks(std::move(Checks)); - LV.setSCEVChecks(LAI.Preds); + LV.setSCEVChecks(LAI.PSE.getUnionPredicate()); LV.versionLoop(); } diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index cc3ff5d80d4..9a2a06cf689 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -32,7 +32,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, assert(L->getLoopPreheader() && "No preheader"); if (UseLAIChecks) { setAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); - setSCEVChecks(LAI.Preds); + setSCEVChecks(LAI.PSE.getUnionPredicate()); } } @@ -58,7 +58,7 @@ void LoopVersioning::versionLoop( LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks); assert(MemRuntimeCheck && "called even though needsAnyChecking = false"); - const SCEVUnionPredicate &Pred = LAI.Preds; + const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate(); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), "scev.check"); SCEVRuntimeCheck = diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 917f2d55f6c..9adc80c8bd0 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -310,15 +310,16 @@ static GetElementPtrInst *getGEPInstruction(Value *Ptr) { /// and reduction variables that were found to a given vectorization factor. class InnerLoopVectorizer { public: - InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, - DominatorTree *DT, const TargetLibraryInfo *TLI, + InnerLoopVectorizer(Loop *OrigLoop, PredicatedScalarEvolution &PSE, + LoopInfo *LI, DominatorTree *DT, + const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, unsigned VecWidth, - unsigned UnrollFactor, SCEVUnionPredicate &Preds) - : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), - VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), + unsigned UnrollFactor) + : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), + VF(VecWidth), UF(UnrollFactor), Builder(PSE.getSE()->getContext()), Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr), - AddedSafetyChecks(false), Preds(Preds) {} + AddedSafetyChecks(false) {} // Perform the actual loop widening (vectorization). // MinimumBitWidths maps scalar integer values to the smallest bitwidth they @@ -486,8 +487,10 @@ protected: /// The original loop. Loop *OrigLoop; - /// Scev analysis to use. - ScalarEvolution *SE; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. Applies + /// dynamic knowledge to simplify SCEV expressions and converts them to a + /// more usable form. + PredicatedScalarEvolution &PSE; /// Loop Info. LoopInfo *LI; /// Dominator Tree. @@ -551,23 +554,15 @@ protected: // Record whether runtime check is added. bool AddedSafetyChecks; - - /// The SCEV predicate containing all the SCEV-related assumptions. - /// The predicate is used to simplify existing expressions in the - /// context of existing SCEV assumptions. Since legality checking is - /// not done here, we don't need to use this predicate to record - /// further assumptions. - SCEVUnionPredicate &Preds; }; class InnerLoopUnroller : public InnerLoopVectorizer { public: - InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI, - DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned UnrollFactor, - SCEVUnionPredicate &Preds) - : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor, - Preds) {} + InnerLoopUnroller(Loop *OrigLoop, PredicatedScalarEvolution &PSE, + LoopInfo *LI, DominatorTree *DT, + const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, unsigned UnrollFactor) + : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, 1, UnrollFactor) {} private: void scalarizeInstruction(Instruction *Instr, @@ -789,9 +784,9 @@ private: /// between the member and the group in a map. class InterleavedAccessInfo { public: - InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT, - SCEVUnionPredicate &Preds) - : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {} + InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, + DominatorTree *DT) + : PSE(PSE), TheLoop(L), DT(DT) {} ~InterleavedAccessInfo() { SmallSet<InterleaveGroup *, 4> DelSet; @@ -821,17 +816,14 @@ public: } private: - ScalarEvolution *SE; + /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. + /// Simplifies SCEV expressions in the context of existing SCEV assumptions. + /// The interleaved access analysis can also add new predicates (for example + /// by versioning strides of pointers). + PredicatedScalarEvolution &PSE; Loop *TheLoop; DominatorTree *DT; - /// The SCEV predicate containing all the SCEV-related assumptions. - /// The predicate is used to simplify SCEV expressions in the - /// context of existing SCEV assumptions. The interleaved access - /// analysis can also add new predicates (for example by versioning - /// strides of pointers). - SCEVUnionPredicate &Preds; - /// Holds the relationships between the members and the interleave group. DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap; @@ -1189,18 +1181,17 @@ static void emitMissedWarning(Function *F, Loop *L, /// induction variable and the different reduction variables. class LoopVectorizationLegality { public: - LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, DominatorTree *DT, - TargetLibraryInfo *TLI, AliasAnalysis *AA, - Function *F, const TargetTransformInfo *TTI, + LoopVectorizationLegality(Loop *L, PredicatedScalarEvolution &PSE, + DominatorTree *DT, TargetLibraryInfo *TLI, + AliasAnalysis *AA, Function *F, + const TargetTransformInfo *TTI, LoopAccessAnalysis *LAA, LoopVectorizationRequirements *R, - const LoopVectorizeHints *H, - SCEVUnionPredicate &Preds) - : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), - InterleaveInfo(SE, L, DT, Preds), Induction(nullptr), - WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H), - Preds(Preds) {} + const LoopVectorizeHints *H) + : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TheFunction(F), + TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(PSE, L, DT), + Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), + Requirements(R), Hints(H) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1347,8 +1338,12 @@ private: /// The loop that we evaluate. Loop *TheLoop; - /// Scev analysis. - ScalarEvolution *SE; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. + /// Applies dynamic knowledge to simplify SCEV expressions in the context + /// of existing SCEV assumptions. The analysis will also add a minimal set + /// of new predicates if this is required to enable vectorization and + /// unrolling. + PredicatedScalarEvolution &PSE; /// Target Library Info. TargetLibraryInfo *TLI; /// Parent function @@ -1403,13 +1398,6 @@ private: /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic SmallPtrSet<const Instruction *, 8> MaskedOp; - - /// The SCEV predicate containing all the SCEV-related assumptions. - /// The predicate is used to simplify SCEV expressions in the - /// context of existing SCEV assumptions. The analysis will also - /// add a minimal set of new predicates if this is required to - /// enable vectorization/unrolling. - SCEVUnionPredicate &Preds; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1427,8 +1415,7 @@ public: const TargetLibraryInfo *TLI, DemandedBits *DB, AssumptionCache *AC, const Function *F, const LoopVectorizeHints *Hints, - SmallPtrSetImpl<const Value *> &ValuesToIgnore, - SCEVUnionPredicate &Preds) + SmallPtrSetImpl<const Value *> &ValuesToIgnore) : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} @@ -1758,12 +1745,12 @@ struct LoopVectorize : public FunctionPass { } } - SCEVUnionPredicate Preds; + PredicatedScalarEvolution PSE(*SE); // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; - LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA, - &Requirements, &Hints, Preds); + LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, LAA, + &Requirements, &Hints); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1781,8 +1768,8 @@ struct LoopVectorize : public FunctionPass { } // Use the cost model. - LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints, - ValuesToIgnore, Preds); + LoopVectorizationCostModel CM(L, PSE.getSE(), LI, &LVL, *TTI, TLI, DB, AC, + F, &Hints, ValuesToIgnore); // Check the function attributes to find out if this function should be // optimized for size. @@ -1893,7 +1880,7 @@ struct LoopVectorize : public FunctionPass { assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop then // interleave it. - InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds); + InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, IC); Unroller.vectorize(&LVL, CM.MinBWs); emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), @@ -1901,7 +1888,7 @@ struct LoopVectorize : public FunctionPass { Twine(IC) + ")"); } else { // If we decided that it is *legal* to vectorize the loop then do it. - InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds); + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, VF.Width, IC); LB.vectorize(&LVL, CM.MinBWs); ++LoopsVectorized; @@ -2002,6 +1989,7 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr"); + auto *SE = PSE.getSE(); // Make sure that the pointer does not point to structs. if (Ptr->getType()->getPointerElementType()->isAggregateType()) return 0; @@ -2031,7 +2019,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { // Make sure that all of the index operands are loop invariant. for (unsigned i = 1; i < NumOperands; ++i) - if (!SE->isLoopInvariant(SE->getSCEV(Gep->getOperand(i)), TheLoop)) + if (!SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop)) return 0; InductionDescriptor II = Inductions[Phi]; @@ -2044,14 +2032,14 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { // operand. for (unsigned i = 0; i != NumOperands; ++i) if (i != InductionOperand && - !SE->isLoopInvariant(SE->getSCEV(Gep->getOperand(i)), TheLoop)) + !SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop)) return 0; // We can emit wide load/stores only if the last non-zero index is the // induction variable. const SCEV *Last = nullptr; if (!Strides.count(Gep)) - Last = SE->getSCEV(Gep->getOperand(InductionOperand)); + Last = PSE.getSCEV(Gep->getOperand(InductionOperand)); else { // Because of the multiplication by a stride we can have a s/zext cast. // We are going to replace this stride by 1 so the cast is safe to ignore. @@ -2062,7 +2050,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { // %idxprom = zext i32 %mul to i64 << Safe cast. // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom // - Last = replaceSymbolicStrideSCEV(SE, Strides, Preds, + Last = replaceSymbolicStrideSCEV(PSE, Strides, Gep->getOperand(InductionOperand), Gep); if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last)) Last = @@ -2420,8 +2408,9 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Ptr = Builder.Insert(Gep2); } else if (Gep) { setDebugLocFromInst(Builder, Gep); - assert(SE->isLoopInvariant(SE->getSCEV(Gep->getPointerOperand()), - OrigLoop) && "Base ptr must be invariant"); + assert(PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getPointerOperand()), + OrigLoop) && + "Base ptr must be invariant"); // The last index does not have to be the induction. It can be // consecutive and be a function of the index. For example A[I+1]; @@ -2438,7 +2427,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { if (i == InductionOperand || (GepOperandInst && OrigLoop->contains(GepOperandInst))) { assert((i == InductionOperand || - SE->isLoopInvariant(SE->getSCEV(GepOperandInst), OrigLoop)) && + PSE.getSE()->isLoopInvariant(PSE.getSCEV(GepOperandInst), + OrigLoop)) && "Must be last index or loop invariant"); VectorParts &GEPParts = getVectorValue(GepOperand); @@ -2658,6 +2648,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); // Find the loop boundaries. + ScalarEvolution *SE = PSE.getSE(); const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(OrigLoop); assert(BackedgeTakenCount != SE->getCouldNotCompute() && "Invalid loop count"); @@ -2765,8 +2756,10 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { // Generate the code to check that the SCEV assumptions that we made. // We want the new basic block to start at the first instruction in a // sequence of instructions that form a check. - SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check"); - Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator()); + SCEVExpander Exp(*PSE.getSE(), Bypass->getModule()->getDataLayout(), + "scev.check"); + Value *SCEVCheck = + Exp.expandCodeForPredicate(&PSE.getUnionPredicate(), BB->getTerminator()); if (auto *C = dyn_cast<ConstantInt>(SCEVCheck)) if (C->isZero()) @@ -3785,8 +3778,9 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // Widen selects. // If the selector is loop invariant we can create a select // instruction with a scalar condition. Otherwise, use vector-select. - bool InvariantCond = SE->isLoopInvariant(SE->getSCEV(it->getOperand(0)), - OrigLoop); + auto *SE = PSE.getSE(); + bool InvariantCond = + SE->isLoopInvariant(PSE.getSCEV(it->getOperand(0)), OrigLoop); setDebugLocFromInst(Builder, &*it); // The condition can be loop invariant but still defined inside the @@ -3967,7 +3961,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { void InnerLoopVectorizer::updateAnalysis() { // Forget the original basic block. - SE->forgetLoop(OrigLoop); + PSE.getSE()->forgetLoop(OrigLoop); // Update the dominator tree information. assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && @@ -4119,10 +4113,10 @@ bool LoopVectorizationLegality::canVectorize() { } // ScalarEvolution needs to be able to find the exit count. - const SCEV *ExitCount = SE->getBackedgeTakenCount(TheLoop); - if (ExitCount == SE->getCouldNotCompute()) { - emitAnalysis(VectorizationReport() << - "could not determine number of loop iterations"); + const SCEV *ExitCount = PSE.getSE()->getBackedgeTakenCount(TheLoop); + if (ExitCount == PSE.getSE()->getCouldNotCompute()) { + emitAnalysis(VectorizationReport() + << "could not determine number of loop iterations"); DEBUG(dbgs() << "LV: SCEV could not compute the loop exit count.\n"); return false; } @@ -4162,7 +4156,7 @@ bool LoopVectorizationLegality::canVectorize() { if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; - if (Preds.getComplexity() > SCEVThreshold) { + if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { emitAnalysis(VectorizationReport() << "Too many SCEV assumptions need to be made and checked " << "at runtime"); @@ -4268,7 +4262,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } InductionDescriptor ID; - if (InductionDescriptor::isInductionPHI(Phi, SE, ID)) { + if (InductionDescriptor::isInductionPHI(Phi, PSE.getSE(), ID)) { Inductions[Phi] = ID; // Get the widest type. if (!WidestIndTy) @@ -4337,7 +4331,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // second argument is the same (i.e. loop invariant) if (CI && hasVectorInstrinsicScalarOpd(getIntrinsicIDForCall(CI, TLI), 1)) { - if (!SE->isLoopInvariant(SE->getSCEV(CI->getOperand(1)), TheLoop)) { + auto *SE = PSE.getSE(); + if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { emitAnalysis(VectorizationReport(&*it) << "intrinsic instruction cannot be vectorized"); DEBUG(dbgs() << "LV: Found unvectorizable intrinsic " << *CI << "\n"); @@ -4410,7 +4405,7 @@ void LoopVectorizationLegality::collectStridedAccess(Value *MemAccess) { else return; - Value *Stride = getStrideFromPointer(Ptr, SE, TheLoop); + Value *Stride = getStrideFromPointer(Ptr, PSE.getSE(), TheLoop); if (!Stride) return; @@ -4474,7 +4469,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); - Preds.add(&LAI->Preds); + PSE.addPredicate(LAI->PSE.getUnionPredicate()); return true; } @@ -4589,7 +4584,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses( StoreInst *SI = dyn_cast<StoreInst>(I); Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds); + int Stride = isStridedPtr(PSE, Ptr, TheLoop, Strides); // The factor of the corresponding interleave group. unsigned Factor = std::abs(Stride); @@ -4598,7 +4593,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses( if (Factor < 2 || Factor > MaxInterleaveGroupFactor) continue; - const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr); + const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType()); unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType()); @@ -4685,8 +4680,8 @@ void InterleavedAccessInfo::analyzeInterleaving( continue; // Calculate the distance and prepare for the rule 3. - const SCEVConstant *DistToA = - dyn_cast<SCEVConstant>(SE->getMinusSCEV(DesB.Scev, DesA.Scev)); + const SCEVConstant *DistToA = dyn_cast<SCEVConstant>( + PSE.getSE()->getMinusSCEV(DesB.Scev, DesA.Scev)); if (!DistToA) continue; |