Diffstat (limited to 'llvm/lib')
-rw-r--r--   llvm/lib/Analysis/ScalarEvolution.cpp   149
1 file changed, 149 insertions, 0 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 9df31ed9e3a..e0658c21050 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -83,6 +83,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -5398,6 +5399,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
         return ItCnt;
     }
 
+  ExitLimit ShiftEL = computeShiftCompareExitLimit(
+      ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond);
+  if (ShiftEL.hasAnyInfo())
+    return ShiftEL;
+
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
 
@@ -5576,6 +5582,149 @@ ScalarEvolution::computeLoadConstantCompareExitLimit(
   return getCouldNotCompute();
 }
 
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+    Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+  ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+  if (!RHS)
+    return getCouldNotCompute();
+
+  const BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return getCouldNotCompute();
+
+  const BasicBlock *Predecessor = L->getLoopPredecessor();
+  if (!Predecessor)
+    return getCouldNotCompute();
+
+  // Return true if V is of the form "LHS `shift_op` <positive constant>".
+  // Return LHS in OutLHS and shift_op in OutOpCode.
+  auto MatchPositiveShift =
+      [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+    using namespace PatternMatch;
+
+    ConstantInt *ShiftAmt;
+    if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::LShr;
+    else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::AShr;
+    else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::Shl;
+    else
+      return false;
+
+    return ShiftAmt->getValue().isStrictlyPositive();
+  };
+
+  // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
+  //
+  // loop:
+  //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+  //   %iv.shifted = lshr i32 %iv, <positive constant>
+  //
+  // Return true on a successful match.  Return the corresponding PHI node (%iv
+  // above) in PNOut and the opcode of the shift operation in OpCodeOut.
+  auto MatchShiftRecurrence =
+      [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+    Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+    {
+      Instruction::BinaryOps OpC;
+      Value *V;
+
+      // If we encounter a shift instruction, "peel off" the shift operation,
+      // and remember that we did so.  Later when we inspect %iv's backedge
+      // value, we will make sure that the backedge value uses the same
+      // operation.
+      //
+      // Note: the peeled shift operation does not have to be the same
+      // instruction as the one feeding into the PHI's backedge value.  We only
+      // really care about it being the same *kind* of shift instruction --
+      // that's all that is required for our later inferences to hold.
+      if (MatchPositiveShift(LHS, V, OpC)) {
+        PostShiftOpCode = OpC;
+        LHS = V;
+      }
+    }
+
+    PNOut = dyn_cast<PHINode>(LHS);
+    if (!PNOut || PNOut->getParent() != L->getHeader())
+      return false;
+
+    Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
+    Value *OpLHS;
+
+    return
+        // The backedge value for the PHI node must be a shift by a positive
+        // amount
+        MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
+
+        // of the PHI node itself
+        OpLHS == PNOut &&
+
+        // and the kind of shift should match the kind of shift we peeled
+        // off, if any.
+        (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
+  };
+
+  PHINode *PN;
+  Instruction::BinaryOps OpCode;
+  if (!MatchShiftRecurrence(LHS, PN, OpCode))
+    return getCouldNotCompute();
+
+  const DataLayout &DL = getDataLayout();
+
+  // The key rationale for this optimization is that for some kinds of shift
+  // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
+  // within a finite number of iterations.  If the condition guarding the
+  // backedge (in the sense that the backedge is taken if the condition is
+  // true) is false for the value the shift recurrence stabilizes to, then we
+  // know that the backedge is taken only a finite number of times.
+
+  ConstantInt *StableValue = nullptr;
+  switch (OpCode) {
+  default:
+    llvm_unreachable("Impossible case!");
+
+  case Instruction::AShr: {
+    // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
+    // bitwidth(K) iterations.
+    Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
+    bool KnownZero, KnownOne;
+    ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr,
+                   Predecessor->getTerminator(), &DT);
+    auto *Ty = cast<IntegerType>(RHS->getType());
+    if (KnownZero)
+      StableValue = ConstantInt::get(Ty, 0);
+    else if (KnownOne)
+      StableValue = ConstantInt::get(Ty, -1, true);
+    else
+      return getCouldNotCompute();
+
+    break;
+  }
+  case Instruction::LShr:
+  case Instruction::Shl:
+    // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
+    // stabilize to 0 in at most bitwidth(K) iterations.
+    StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
+    break;
+  }
+
+  auto *Result =
+      ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
+  assert(Result->getType()->isIntegerTy(1) &&
+         "Otherwise cannot be an operand to a branch instruction");
+
+  if (Result->isZeroValue()) {
+    unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+    const SCEV *UpperBound =
+        getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
+    return ExitLimit(getCouldNotCompute(), UpperBound);
+  }
+
+  return getCouldNotCompute();
+}
+
 /// CanConstantFold - Return true if we can constant fold an instruction of the
 /// specified type, assuming that all operands were constants.
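
For context, a minimal C++ sketch (not part of the patch, function name hypothetical) of the kind of loop the new computeShiftCompareExitLimit handles: the exit condition compares an lshr recurrence against a ConstantInt, and because {n, lshr, 1} stabilizes to 0 within bitwidth(n) iterations, ScalarEvolution can now report a conservative upper bound on the backedge-taken count even though the exact count depends on the runtime value of n.

// Hypothetical example: lowered to IR this loop typically becomes a header PHI
// %iv with backedge value "lshr i32 %iv, 1" and exit condition
// "icmp ne i32 %iv, 0", which MatchShiftRecurrence recognizes.  The recurrence
// reaches 0 after at most 32 shifts, so the backedge-taken count is bounded by
// 32 (the bit width of the compared type), the UpperBound in the ExitLimit.
unsigned count_shifts_to_zero(unsigned n) {
  unsigned iters = 0;
  while (n != 0) { // compare the shift recurrence against a constant
    n >>= 1;       // lshr by a positive constant amount
    ++iters;
  }
  return iters;
}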


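The Instruction::AShr case needs the extra ComputeSignBit query because {K, ashr, <positive constant>} stabilizes to signum(K): 0 when the start value K is known non-negative, -1 when it is known negative, and unknown otherwise. A hedged sketch, assuming a 32-bit int; the mask before the loop is what makes the sign bit provable at the loop predecessor:

// Hypothetical example: "x &= 0x7fffffff" makes the sign bit of the
// recurrence's start value provably zero at the loop predecessor, so the ashr
// recurrence {x, ashr, 1} stabilizes to 0 and the backedge is taken at most
// 32 times.  Without the mask the sign of the start value is unknown and the
// new code conservatively returns getCouldNotCompute().
long long sum_shifted_prefixes(int x) {
  x &= 0x7fffffff;    // sign bit known zero before entering the loop
  long long sum = 0;  // wide accumulator so the sum cannot overflow
  while (x != 0) {    // exit condition on the ashr recurrence
    sum += x;
    x >>= 1;          // arithmetic shift right by a positive constant
  }
  return sum;
}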