Diffstat (limited to 'llvm')
-rw-r--r--  llvm/include/llvm/Analysis/ScalarEvolution.h                     |  48
-rw-r--r--  llvm/lib/Analysis/ScalarEvolution.cpp                            | 385
-rw-r--r--  llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll | 240
3 files changed, 648 insertions(+), 25 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c7accfae78b..d1b182755cf 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -237,17 +237,15 @@ struct FoldingSetTrait<SCEVPredicate> : DefaultFoldingSetTrait<SCEVPredicate> {
};
/// This class represents an assumption that two SCEV expressions are equal,
-/// and this can be checked at run-time. We assume that the left hand side is
-/// a SCEVUnknown and the right hand side a constant.
+/// and this can be checked at run-time.
class SCEVEqualPredicate final : public SCEVPredicate {
- /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
- /// constant.
- const SCEVUnknown *LHS;
- const SCEVConstant *RHS;
+ /// We assume that LHS == RHS.
+ const SCEV *LHS;
+ const SCEV *RHS;
public:
- SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
- const SCEVConstant *RHS);
+ SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEV *LHS,
+ const SCEV *RHS);
/// Implementation of the SCEVPredicate interface
bool implies(const SCEVPredicate *N) const override;
@@ -256,10 +254,10 @@ public:
const SCEV *getExpr() const override;
/// Returns the left hand side of the equality.
- const SCEVUnknown *getLHS() const { return LHS; }
+ const SCEV *getLHS() const { return LHS; }
/// Returns the right hand side of the equality.
- const SCEVConstant *getRHS() const { return RHS; }
+ const SCEV *getRHS() const { return RHS; }
/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const SCEVPredicate *P) {
@@ -1241,6 +1239,14 @@ public:
SmallVector<const SCEV *, 4> NewOp(Operands.begin(), Operands.end());
return getAddRecExpr(NewOp, L, Flags);
}
+
+  /// Checks if \p SymbolicPHI can be rewritten as an AddRecExpr under some
+  /// Predicates. If successful, returns them as a pair <AddRecExpr, Predicates>.
+ /// The function is intended to be called from PSCEV (the caller will decide
+ /// whether to actually add the predicates and carry out the rewrites).
+ Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI);
+
/// Returns an expression for a GEP
///
/// \p GEP The GEP. The indices contained in the GEP itself are ignored,
@@ -1675,8 +1681,7 @@ public:
return F.getParent()->getDataLayout();
}
- const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
- const SCEVConstant *RHS);
+ const SCEVPredicate *getEqualPredicate(const SCEV *LHS, const SCEV *RHS);
const SCEVPredicate *
getWrapPredicate(const SCEVAddRecExpr *AR,
@@ -1692,6 +1697,19 @@ public:
SmallPtrSetImpl<const SCEVPredicate *> &Preds);
private:
+ /// Similar to createAddRecFromPHI, but with the additional flexibility of
+ /// suggesting runtime overflow checks in case casts are encountered.
+ /// If successful, the analysis records that for this loop, \p SymbolicPHI,
+ /// which is the UnknownSCEV currently representing the PHI, can be rewritten
+  /// into an AddRec, assuming some predicates. The function then returns the
+ /// AddRec and the predicates as a pair, and caches this pair in
+ /// PredicatedSCEVRewrites.
+  /// If the analysis is not successful, a mapping from the \p SymbolicPHI to
+  /// itself (with no predicates) is recorded, and None is returned.
+ Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI);
+
/// Compute the backedge taken count knowing the interval difference, the
/// stride and presence of the equality in the comparison.
const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride,
@@ -1722,6 +1740,12 @@ private:
FoldingSet<SCEVPredicate> UniquePreds;
BumpPtrAllocator SCEVAllocator;
+  /// Cache tentative mappings from UnknownSCEVs in a Loop to a SCEV expression
+ /// they can be rewritten into under certain predicates.
+ DenseMap<std::pair<const SCEVUnknown *, const Loop *>,
+ std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ PredicatedSCEVRewrites;
+
/// The head of a linked list of all SCEVUnknown values that have been
/// allocated. This is used by releaseMemory to locate them all and call
/// their destructors.
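
The new createAddRecFromPHIWithCasts entry point is intended to be driven from
the predicated-SCEV rewriter shown further below in this patch. A minimal
caller sketch (illustrative only; SE, SymbolicPHI, and Preds are assumed to be
in scope, and this is not code from the patch):

    // Ask SCEV whether the symbolic PHI can be rewritten as an AddRec under
    // runtime-checkable predicates; the caller decides whether to commit them.
    if (auto Rewrite = SE.createAddRecFromPHIWithCasts(SymbolicPHI)) {
      const SCEV *AddRec = Rewrite->first;
      for (const SCEVPredicate *P : Rewrite->second)
        Preds.insert(P); // the runtime checks under which the rewrite holds
      // ... AddRec can now stand in for SymbolicPHI ...
    }
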
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3fb1ab980ad..b973203a89b 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4173,6 +4173,319 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
return None;
}
+/// Helper function to createAddRecFromPHIWithCasts. We have a phi
+/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
+/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
+/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
+/// follows one of the following patterns:
+/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// If the SCEV expression of \p Op conforms to one of the expected patterns,
+/// we return the type of the truncation operation, and indicate whether the
+/// truncated type should be treated as signed/unsigned by setting
+/// \p Signed to true/false, respectively.
+static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
+ bool &Signed, ScalarEvolution &SE) {
+ // The case where Op == SymbolicPHI (that is, with no type conversions on
+ // the way) is handled by the regular add recurrence creating logic and
+  // would have already been triggered in createAddRecFromPHI. Reaching it here
+ // means that createAddRecFromPHI had failed for this PHI before (e.g.,
+ // because one of the other operands of the SCEVAddExpr updating this PHI is
+ // not invariant).
+ //
+ // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
+ // this case predicates that allow us to prove that Op == SymbolicPHI will
+ // be added.
+ if (Op == SymbolicPHI)
+ return nullptr;
+
+ unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
+ unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
+ if (SourceBits != NewBits)
+ return nullptr;
+
+ const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
+ const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
+ if (!SExt && !ZExt)
+ return nullptr;
+ const SCEVTruncateExpr *Trunc =
+ SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
+ : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
+ if (!Trunc)
+ return nullptr;
+ const SCEV *X = Trunc->getOperand();
+ if (X != SymbolicPHI)
+ return nullptr;
+  Signed = SExt != nullptr;
+ return Trunc->getType();
+}
+
+static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
+ if (!PN->getType()->isIntegerTy())
+ return nullptr;
+ const Loop *L = LI.getLoopFor(PN->getParent());
+ if (!L || L->getHeader() != PN->getParent())
+ return nullptr;
+ return L;
+}
+
+// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
+// computation that updates the phi matches the following pattern:
+//   (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
+// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
+// If so, try to see if it can be rewritten as an AddRecExpr under some
+// Predicates. If successful, return them as a pair. Also cache the results
+// of the analysis.
+//
+// Example usage scenario:
+// Say the Rewriter is called for the following SCEV:
+// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+// where:
+// %X = phi i64 (%Start, %BEValue)
+// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
+// and call this function with %SymbolicPHI = %X.
+//
+// The analysis will find that the value coming around the backedge has
+// the following SCEV:
+// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+// Upon concluding that this matches the desired pattern, the function
+// will return the pair {NewAddRec, SmallPredsVec} where:
+// NewAddRec = {%Start,+,%Step}
+// SmallPredsVec = {P1, P2, P3} as follows:
+// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
+// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
+// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
+// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
+// under the predicates {P1,P2,P3}.
+// This predicated rewrite will be cached in PredicatedSCEVRewrites:
+// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
+//
+// TODOs:
+//
+// 1) Extend the Induction descriptor to also support inductions that involve
+// casts: When needed (namely, when we are called in the context of the
+// vectorizer induction analysis), a Set of cast instructions will be
+// populated by this method, and provided back to isInductionPHI. This is
+// needed to allow the vectorizer to properly record them to be ignored by
+// the cost model and to avoid vectorizing them (otherwise these casts,
+// which are redundant under the runtime overflow checks, will be
+// vectorized, which can be costly).
+//
+// 2) Support additional induction/PHISCEV patterns: We also want to support
+// inductions where the sext-trunc / zext-trunc operations (partly) occur
+// after the induction update operation (the induction increment):
+//
+//    (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
+//    which corresponds to a phi->add->trunc->sext/zext->phi update chain.
+//
+//    (Trunc iy ((SExt/ZExt ix (%SymbolicPHI) to iy) + InvariantAccum) to ix)
+//    which corresponds to a phi->trunc->add->sext/zext->phi update chain.
+//
+// 3) Outline common code with createAddRecFromPHI to avoid duplication.
+//
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCastsImpl(
+    const SCEVUnknown *SymbolicPHI) {
+ SmallVector<const SCEVPredicate *, 3> Predicates;
+
+ // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
+ // return an AddRec expression under some predicate.
+
+ auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+ const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  assert(L && "Expecting an integer loop header phi");
+
+ // The loop may have multiple entrances or multiple exits; we can analyze
+ // this phi as an addrec if it has a unique entry value and a unique
+ // backedge value.
+ Value *BEValueV = nullptr, *StartValueV = nullptr;
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ Value *V = PN->getIncomingValue(i);
+ if (L->contains(PN->getIncomingBlock(i))) {
+ if (!BEValueV) {
+ BEValueV = V;
+ } else if (BEValueV != V) {
+ BEValueV = nullptr;
+ break;
+ }
+ } else if (!StartValueV) {
+ StartValueV = V;
+ } else if (StartValueV != V) {
+ StartValueV = nullptr;
+ break;
+ }
+ }
+ if (!BEValueV || !StartValueV)
+ return None;
+
+ const SCEV *BEValue = getSCEV(BEValueV);
+
+ // If the value coming around the backedge is an add with the symbolic
+ // value we just inserted, possibly with casts that we can ignore under
+ // an appropriate runtime guard, then we found a simple induction variable!
+ const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
+ if (!Add)
+ return None;
+
+ // If there is a single occurrence of the symbolic value, possibly
+ // casted, replace it with a recurrence.
+ unsigned FoundIndex = Add->getNumOperands();
+ Type *TruncTy = nullptr;
+ bool Signed;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if ((TruncTy =
+ isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
+ if (FoundIndex == e) {
+ FoundIndex = i;
+ break;
+ }
+
+ if (FoundIndex == Add->getNumOperands())
+ return None;
+
+ // Create an add with everything but the specified operand.
+ SmallVector<const SCEV *, 8> Ops;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if (i != FoundIndex)
+ Ops.push_back(Add->getOperand(i));
+ const SCEV *Accum = getAddExpr(Ops);
+
+ // The runtime checks will not be valid if the step amount is
+ // varying inside the loop.
+ if (!isLoopInvariant(Accum, L))
+ return None;
+
+ // *** Part2: Create the predicates
+
+ // Analysis was successful: we have a phi-with-cast pattern for which we
+ // can return an AddRec expression under the following predicates:
+ //
+ // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
+ // fits within the truncated type (does not overflow) for i = 0 to n-1.
+ // P2: An Equal predicate that guarantees that
+ // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
+ // P3: An Equal predicate that guarantees that
+ // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
+ //
+ // As we next prove, the above predicates guarantee that:
+ // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
+ //
+ //
+ // More formally, we want to prove that:
+ // Expr(i+1) = Start + (i+1) * Accum
+ // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
+ //
+ // Given that:
+ // 1) Expr(0) = Start
+ // 2) Expr(1) = Start + Accum
+ // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
+ // 3) Induction hypothesis (step i):
+ // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
+ //
+ // Proof:
+ // Expr(i+1) =
+ // = Start + (i+1)*Accum
+ // = (Start + i*Accum) + Accum
+ // = Expr(i) + Accum
+ // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
+ // :: from step i
+ //
+ // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
+ //
+ // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
+ // + (Ext ix (Trunc iy (Accum) to ix) to iy)
+ // + Accum :: from P3
+ //
+ // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
+ // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
+ //
+ // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
+ // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
+ //
+  // By induction, the same applies to all iterations 1<=i<n.
+ //
+
+  // Create a truncated addrec for which we will add a no-overflow check (P1).
+ const SCEV *StartVal = getSCEV(StartValueV);
+ const SCEV *PHISCEV =
+ getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
+ getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
+ const auto *AR = cast<SCEVAddRecExpr>(PHISCEV);
+
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
+ Signed ? SCEVWrapPredicate::IncrementNSSW
+ : SCEVWrapPredicate::IncrementNUSW;
+ const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
+ Predicates.push_back(AddRecPred);
+
+ // Create the Equal Predicates P2,P3:
+ auto AppendPredicate = [&](const SCEV *Expr) -> void {
+    assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
+ const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
+ const SCEV *ExtendedExpr =
+ Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType())
+ : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+ if (Expr != ExtendedExpr &&
+ !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
+ const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
+      DEBUG(dbgs() << "Added Predicate: " << *Pred);
+ Predicates.push_back(Pred);
+ }
+ };
+
+ AppendPredicate(StartVal);
+ AppendPredicate(Accum);
+
+  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
+  // which the casts have been folded away. The caller can rewrite SymbolicPHI
+  // into NewAR if it will also add the runtime overflow checks specified in
+  // Predicates.
+ auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
+
+ std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
+ std::make_pair(NewAR, Predicates);
+  // Remember the result of the analysis for this SCEV at this location.
+ PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
+ return PredRewrite;
+}
+
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCasts(
+    const SCEVUnknown *SymbolicPHI) {
+ auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+ const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+ if (!L)
+ return None;
+
+ // Check to see if we already analyzed this PHI.
+ auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
+ if (I != PredicatedSCEVRewrites.end()) {
+ std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
+ I->second;
+ // Analysis was done before and failed to create an AddRec:
+ if (Rewrite.first == SymbolicPHI)
+ return None;
+    // Analysis was done before and succeeded in creating an AddRec under
+    // a predicate:
+ assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
+ assert(!(Rewrite.second).empty() && "Expected to find Predicates");
+ return Rewrite;
+ }
+
+ Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
+
+  // Record in the cache that the analysis failed.
+ if (!Rewrite) {
+ SmallVector<const SCEVPredicate *, 3> Predicates;
+ PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
+ return None;
+ }
+
+ return Rewrite;
+}
+
/// A helper function for createAddRecFromPHI to handle simple cases.
///
/// This function tries to find an AddRec expression for the simplest (yet most
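
To see why the wrap predicate P1 created above is necessary, it helps to plug
concrete numbers into the signed case with TruncTy = i8. A standalone sketch
(plain C++, illustrative only, not code from the patch):

    #include <cstdint>
    #include <cstdio>
    // With Start = 0 and Accum = 100, the rewritten recurrence {0,+,100}
    // reaches 200 at i = 2, but the original PHI truncates to i8 and
    // sign-extends on every trip around the loop.
    int main() {
      int32_t Linear = 0 + 2 * 100;      // AddRec value at i = 2: 200
      int8_t Trunc = (int8_t)Linear;     // i8 wraps: -56
      int32_t Ext = (int32_t)Trunc;      // sext(trunc(200)) = -56 != 200
      printf("%d vs %d\n", Linear, Ext); // prints: 200 vs -56
      return 0;
    }

Once sext(trunc(Expr(i))) diverges from Expr(i), the cast-free AddRec no longer
describes the PHI; this is exactly the overflow that the <nssw> runtime check
rejects.
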
@@ -5904,6 +6217,16 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
+ // Drop information about predicated SCEV rewrites for this loop.
+ for (auto I = PredicatedSCEVRewrites.begin();
+ I != PredicatedSCEVRewrites.end();) {
+ std::pair<const SCEV *, const Loop *> Entry = I->first;
+ if (Entry.second == L)
+ PredicatedSCEVRewrites.erase(I++);
+ else
+ ++I;
+ }
+
// Drop information about expressions based on loop-header PHIs.
SmallVector<Instruction *, 16> Worklist;
PushLoopPHIs(L, Worklist);
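
The erase loop added to forgetLoop above uses the standard erase-while-iterating
idiom for DenseMap, whose erase(iterator) returns void rather than the next
iterator. The same pattern in isolation (illustrative sketch):

    #include "llvm/ADT/DenseMap.h"
    // Post-incrementing the iterator before the erase keeps the traversal
    // valid while entries are removed.
    static void removeKey(llvm::DenseMap<int, int> &Map, int Key) {
      for (auto I = Map.begin(); I != Map.end();) {
        if (I->first == Key)
          Map.erase(I++);
        else
          ++I;
      }
    }
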
@@ -10062,6 +10385,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
+ PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
}
@@ -10462,6 +10786,15 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
HasRecMap.erase(S);
MinTrailingZerosCache.erase(S);
+ for (auto I = PredicatedSCEVRewrites.begin();
+ I != PredicatedSCEVRewrites.end();) {
+ std::pair<const SCEV *, const Loop *> Entry = I->first;
+ if (Entry.first == S)
+ PredicatedSCEVRewrites.erase(I++);
+ else
+ ++I;
+ }
+
auto RemoveSCEVFromBackedgeMap =
[S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
for (auto I = Map.begin(), E = Map.end(); I != E;) {
@@ -10621,10 +10954,11 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
-const SCEVPredicate *
-ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
- const SCEVConstant *RHS) {
+const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
+ const SCEV *RHS) {
FoldingSetNodeID ID;
+ assert(LHS->getType() == RHS->getType() &&
+ "Type mismatch between LHS and RHS");
// Unique this node based on the arguments
ID.AddInteger(SCEVPredicate::P_Equal);
ID.AddPointer(LHS);
@@ -10687,8 +11021,7 @@ public:
if (IPred->getLHS() == Expr)
return IPred->getRHS();
}
-
- return Expr;
+ return convertToAddRecWithPreds(Expr);
}
const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
@@ -10724,17 +11057,41 @@ public:
}
private:
- bool addOverflowAssumption(const SCEVAddRecExpr *AR,
- SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
- auto *A = SE.getWrapPredicate(AR, AddedFlags);
+ bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
- return Pred && Pred->implies(A);
+ return Pred && Pred->implies(P);
}
- NewPreds->insert(A);
+ NewPreds->insert(P);
return true;
}
+ bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+ auto *A = SE.getWrapPredicate(AR, AddedFlags);
+ return addOverflowAssumption(A);
+ }
+
+ // If \p Expr represents a PHINode, we try to see if it can be represented
+ // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
+ // to add this predicate as a runtime overflow check, we return the AddRec.
+ // If \p Expr does not meet these conditions (is not a PHI node, or we
+ // couldn't create an AddRec for it, or couldn't add the predicate), we just
+ // return \p Expr.
+ const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
+ if (!isa<PHINode>(Expr->getValue()))
+ return Expr;
+ Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
+ if (!PredicatedRewrite)
+ return Expr;
+    for (auto *P : PredicatedRewrite->second) {
+ if (!addOverflowAssumption(P))
+ return Expr;
+ }
+ return PredicatedRewrite->first;
+ }
+
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
SCEVUnionPredicate *Pred;
const Loop *L;
@@ -10771,9 +11128,11 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
: FastID(ID), Kind(Kind) {}
SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
- const SCEVUnknown *LHS,
- const SCEVConstant *RHS)
- : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+ const SCEV *LHS, const SCEV *RHS)
+ : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {
+ assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
+ assert(LHS != RHS && "LHS and RHS are the same SCEV");
+}
bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
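
Since SCEVEqualPredicate now accepts arbitrary SCEVs on both sides, an equality
guard like P2/P3 can be built directly from ScalarEvolution. A hedged sketch
(Step is an assumed loop-invariant Value, Int8Ty an assumed i8 type, and Preds
a predicate vector; not code from the patch):

    // Build the predicate %step == sext(trunc(%step)), the P3-style check
    // that guards the cast-free rewrite.
    const SCEV *StepS = SE.getSCEV(Step);
    const SCEV *ExtTrunc = SE.getSignExtendExpr(
        SE.getTruncateExpr(StepS, Int8Ty), StepS->getType());
    if (StepS != ExtTrunc) // otherwise SCEV already folded the casts away
      Preds.push_back(SE.getEqualPredicate(StepS, ExtTrunc));
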
diff --git a/llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll b/llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
new file mode 100644
index 00000000000..40af8f3adf0
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
@@ -0,0 +1,240 @@
+; RUN: opt -S -loop-vectorize -force-vector-width=4 -force-vector-interleave=1 < %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+; Check that the vectorizer identifies the %p.09 phi
+; as an induction variable, despite the potential overflow
+; due to the truncation from 32-bit to 8-bit.
+; SCEV will detect the pattern "sext(trunc(%p.09)) + %step"
+; and generate the required runtime checks under which
+; we can assume no overflow. We check here that we generate
+; exactly two runtime checks:
+; 1) an overflow check:
+; {0,+,(trunc i32 %step to i8)}<%for.body> Added Flags: <nssw>
+; 2) an equality check verifying that the step of the induction
+; is equal to sext(trunc(step)):
+; Equal predicate: %step == (sext i8 (trunc i32 %step to i8) to i32)
+;
+; See also pr30654.
+;
+; int a[N];
+; void doit1(int n, int step) {
+; int i;
+; char p = 0;
+; for (i = 0; i < n; i++) {
+; a[i] = p;
+; p = p + step;
+; }
+; }
+;
+
+; CHECK-LABEL: @doit1
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: %[[TEST:[0-9]+]] = or i1 {{.*}}, %mul.overflow
+; CHECK: %[[NTEST:[0-9]+]] = or i1 false, %[[TEST]]
+; CHECK: %ident.check = icmp ne i32 {{.*}}, %{{.*}}
+; CHECK: %{{.*}} = or i1 %[[NTEST]], %ident.check
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+@a = common local_unnamed_addr global [250 x i32] zeroinitializer, align 16
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit1(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+ %cmp7 = icmp sgt i32 %n, 0
+ br i1 %cmp7, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+ %wide.trip.count = zext i32 %n to i64
+ br label %for.body
+
+for.body:
+ %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+ %p.09 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+ %sext = shl i32 %p.09, 24
+ %conv = ashr exact i32 %sext, 24
+ %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+ store i32 %conv, i32* %arrayidx, align 4
+ %add = add nsw i32 %conv, %step
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+ br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}
+
+; Same as above, but for checking the SCEV "zext(trunc(%p.09)) + %step".
+; Here we expect the following two predicates to be added for runtime checking:
+; 1) {0,+,(trunc i32 %step to i8)}<%for.body> Added Flags: <nusw>
+; 2) Equal predicate: %step == (zext i8 (trunc i32 %step to i8) to i32)
+;
+; int a[N];
+; void doit2(int n, int step) {
+; int i;
+; unsigned char p = 0;
+; for (i = 0; i < n; i++) {
+; a[i] = p;
+; p = p + step;
+; }
+; }
+;
+
+; CHECK-LABEL: @doit2
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: %[[TEST:[0-9]+]] = or i1 {{.*}}, %mul.overflow
+; CHECK: %[[NTEST:[0-9]+]] = or i1 false, %[[TEST]]
+; CHECK: %ident.check = icmp ne i32 {{.*}}, %{{.*}}
+; CHECK: %{{.*}} = or i1 %[[NTEST]], %ident.check
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit2(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+ %cmp7 = icmp sgt i32 %n, 0
+ br i1 %cmp7, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+ %wide.trip.count = zext i32 %n to i64
+ br label %for.body
+
+for.body:
+ %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+ %p.09 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+ %conv = and i32 %p.09, 255
+ %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+ store i32 %conv, i32* %arrayidx, align 4
+ %add = add nsw i32 %conv, %step
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+ br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}
+
+; Here we check that the same phi SCEV analysis fails
+; to create the runtime checks because the step is not invariant.
+; As a result, vectorization will fail.
+;
+; int a[N];
+; void doit3(int n, int step) {
+; int i;
+; char p = 0;
+; for (i = 0; i < n; i++) {
+; a[i] = p;
+; p = p + step;
+; step += 2;
+; }
+; }
+;
+
+; CHECK-LABEL: @doit3
+; CHECK-NOT: vector.scevcheck
+; CHECK-NOT: vector.body:
+; CHECK-LABEL: for.body:
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit3(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+ %cmp9 = icmp sgt i32 %n, 0
+ br i1 %cmp9, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+ %wide.trip.count = zext i32 %n to i64
+ br label %for.body
+
+for.body:
+ %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+ %p.012 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+ %step.addr.010 = phi i32 [ %add3, %for.body ], [ %step, %for.body.preheader ]
+ %sext = shl i32 %p.012, 24
+ %conv = ashr exact i32 %sext, 24
+ %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+ store i32 %conv, i32* %arrayidx, align 4
+ %add = add nsw i32 %conv, %step.addr.010
+ %add3 = add nsw i32 %step.addr.010, 2
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+ br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}
+
+
+; Lastly, we also check the case where we can tell at compile time that
+; the step of the induction is equal to sext(trunc(step)), in which case
+; we don't have to check this equality at runtime (we only need the
+; runtime overflow check). Therefore only the following overflow predicate
+; will be added for runtime checking:
+; {0,+,%cstep}<%for.body> Added Flags: <nssw>
+;
+; int a[N];
+; void doit4(int n, char cstep) {
+; int i;
+; char p = 0;
+; int istep = cstep;
+; for (i = 0; i < n; i++) {
+; a[i] = p;
+; p = p + istep;
+; }
+; }
+
+; CHECK-LABEL: @doit4
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: %{{.*}} = or i1 {{.*}}, %mul.overflow
+; CHECK-NOT: %ident.check = icmp ne i32 {{.*}}, %{{.*}}
+; CHECK-NOT: %{{.*}} = or i1 %{{.*}}, %ident.check
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit4(i32 %n, i8 signext %cstep) local_unnamed_addr {
+entry:
+ %conv = sext i8 %cstep to i32
+ %cmp10 = icmp sgt i32 %n, 0
+ br i1 %cmp10, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+ %wide.trip.count = zext i32 %n to i64
+ br label %for.body
+
+for.body:
+ %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+ %p.011 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+ %sext = shl i32 %p.011, 24
+ %conv2 = ashr exact i32 %sext, 24
+ %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+ store i32 %conv2, i32* %arrayidx, align 4
+ %add = add nsw i32 %conv2, %conv
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+ br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}