diff options
| -rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 94 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 2 | ||||
| -rw-r--r-- | llvm/test/Analysis/ScalarEvolution/pr34538.ll | 39 | 
3 files changed, 127 insertions, 8 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index d48e8a57562..7795a337e28 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -4080,6 +4080,85 @@ private:    bool Valid = true;  }; +/// This class evaluates the compare condition by matching it against the +/// condition of loop latch. If there is a match we assume a true value +/// for the condition while building SCEV nodes. +class SCEVBackedgeConditionFolder +    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> { +public: +  static const SCEV *rewrite(const SCEV *S, const Loop *L, +                             ScalarEvolution &SE) { +    bool IsPosBECond; +    Value *BECond = nullptr; +    if (BasicBlock *Latch = L->getLoopLatch()) { +      BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator()); +      if (BI && BI->isConditional() && +          BI->getSuccessor(0) != BI->getSuccessor(1)) { +        BECond = BI->getCondition(); +        IsPosBECond = BI->getSuccessor(0) == L->getHeader(); +      } else { +        return S; +      } +    } +    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE); +    return Rewriter.visit(S); +  } + +  const SCEV *visitUnknown(const SCEVUnknown *Expr) { +    const SCEV *Result = Expr; +    bool InvariantF = SE.isLoopInvariant(Expr, L); + +    if (!InvariantF) { +      Instruction *I = cast<Instruction>(Expr->getValue()); +      switch (I->getOpcode()) { +      case Instruction::Select: { +        SelectInst *SI = cast<SelectInst>(I); +        Optional<const SCEV *> Res = +            compareWithBackedgeCondition(SI->getCondition()); +        if (Res.hasValue()) { +          bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne(); +          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue()); +        } +        break; +      } +      default: { +        Optional<const SCEV *> Res = compareWithBackedgeCondition(I); +        if (Res.hasValue()) +          Result = Res.getValue(); +        break; +      } +      } +    } +    return Result; +  } + +private: +  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond, +                                       bool IsPosBECond, ScalarEvolution &SE) +      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond), +        IsPositiveBECond(IsPosBECond) {} + +  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC); + +  const Loop *L; +  /// Loop back condition. +  Value *BackedgeCond = nullptr; +  /// Set to true if loop back is on positive branch condition. +  bool IsPositiveBECond; +}; + +Optional<const SCEV *> +SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) { + +  // If value matches the backedge condition for loop latch, +  // then return a constant evolution node based on loopback +  // branch taken. +  if (BackedgeCond == IC) +    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext())) +                            : SE.getZero(Type::getInt1Ty(SE.getContext())); +  return None; +} +  class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {  public:    SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) @@ -4753,7 +4832,8 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {        SmallVector<const SCEV *, 8> Ops;        for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)          if (i != FoundIndex) -          Ops.push_back(Add->getOperand(i)); +          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i), +                                                             L, *this));        const SCEV *Accum = getAddExpr(Ops);        // This is not a valid addrec if the step amount is varying each @@ -4779,33 +4859,33 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {            // indices form a positive value.            if (GEP->isInBounds() && GEP->getOperand(0) == PN) {              Flags = setFlags(Flags, SCEV::FlagNW); - +                const SCEV *Ptr = getSCEV(GEP->getPointerOperand());              if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))                Flags = setFlags(Flags, SCEV::FlagNUW);            } - +              // We cannot transfer nuw and nsw flags from subtraction            // operations -- sub nuw X, Y is not the same as add nuw X, -Y            // for instance.          } - +            const SCEV *StartVal = getSCEV(StartValueV);          const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); - +            // Okay, for the entire analysis of this edge we assumed the PHI          // to be symbolic.  We now need to go back and purge all of the          // entries for the scalars that use the symbolic expression.          forgetSymbolicName(PN, SymbolicName);          ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - +            // We can add Flags to the post-inc expression only if we          // know that it is *undefined behavior* for BEValueV to          // overflow.          if (auto *BEInst = dyn_cast<Instruction>(BEValueV))            if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))              (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); - +            return PHISCEV;        }      } diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 7f03f2379e7..a161c839b8d 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -2970,7 +2970,7 @@ void LSRInstance::CollectChains() {        // consider leaf IV Users. This effectively rediscovers a portion of        // IVUsers analysis but in program order this time.        if (SE.isSCEVable(I.getType()) && !isa<SCEVUnknown>(SE.getSCEV(&I))) -        continue; +          continue;        // Remove this instruction from any NearUsers set it may be in.        for (unsigned ChainIdx = 0, NChains = IVChainVec.size(); diff --git a/llvm/test/Analysis/ScalarEvolution/pr34538.ll b/llvm/test/Analysis/ScalarEvolution/pr34538.ll new file mode 100644 index 00000000000..abef58e4968 --- /dev/null +++ b/llvm/test/Analysis/ScalarEvolution/pr34538.ll @@ -0,0 +1,39 @@ +; RUN: opt -scalar-evolution -loop-deletion -simplifycfg -analyze < %s | FileCheck %s --check-prefix=CHECK-ANALYSIS-1 +; RUN: opt -analyze -scalar-evolution < %s | FileCheck %s --check-prefix=CHECK-ANALYSIS-2 + +define i32 @pr34538() local_unnamed_addr #0 { +; CHECK-ANALYSIS-1: Loop %do.body: backedge-taken count is 10000 +; CHECK-ANALYSIS-1: Loop %do.body: max backedge-taken count is 10000 +; CHECK-ANALYSIS-1: Loop %do.body: Predicated backedge-taken count is 10000 +entry: +  br label %do.body + +do.body:                                          ; preds = %do.body, %entry +  %start.0 = phi i32 [ 0, %entry ], [ %inc.start.0, %do.body ] +  %cmp = icmp slt i32 %start.0, 10000 +  %inc = zext i1 %cmp to i32 +  %inc.start.0 = add nsw i32 %start.0, %inc +  br i1 %cmp, label %do.body, label %do.end + +do.end:                                           ; preds = %do.body +  ret i32 0 +} + + +define i32 @foo() { +entry: +  br label %do.body + +do.body:                                          ; preds = %do.body, %entry +  %start.0 = phi i32 [ 0, %entry ], [ %inc.start.0, %do.body ] +  %cmp = icmp slt i32 %start.0, 10000 +  %select_ext = select i1 %cmp, i32 2 , i32 1 +  %inc.start.0 = add nsw i32 %start.0, %select_ext +  br i1 %cmp, label %do.body, label %do.end + +do.end:                                           ; preds = %do.body +  ret i32 0 +; CHECK-ANALYSIS-2: Loop %do.body: backedge-taken count is 5000 +; CHECK-ANALYSIS-2: Loop %do.body: max backedge-taken count is 5000 +; CHECK-ANALYSIS-2: Loop %do.body: Predicated backedge-taken count is 5000 +}  | 

