diff options
Diffstat (limited to 'llvm/lib/Transforms')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 190 | 
1 files changed, 96 insertions, 94 deletions
| diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index e32cacc4412..0d6df2fbae3 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -2133,15 +2133,12 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,      return Cond;    const SCEVConstant *SC = dyn_cast<SCEVConstant>(*CondStride);    if (!SC) return Cond; -  ConstantInt *C = dyn_cast<ConstantInt>(Cond->getOperand(1)); -  if (!C) return Cond;    ICmpInst::Predicate Predicate = Cond->getPredicate();    int64_t CmpSSInt = SC->getValue()->getSExtValue(); -  int64_t CmpVal = C->getValue().getSExtValue(); -  unsigned BitWidth = C->getValue().getBitWidth(); +  unsigned BitWidth = (*CondStride)->getBitWidth();    uint64_t SignBit = 1ULL << (BitWidth-1); -  const Type *CmpTy = C->getType(); +  const Type *CmpTy = Cond->getOperand(0)->getType();    const Type *NewCmpTy = NULL;    unsigned TyBits = CmpTy->getPrimitiveSizeInBits();    unsigned NewTyBits = 0; @@ -2149,102 +2146,112 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,    Value *NewCmpLHS = NULL;    Value *NewCmpRHS = NULL;    int64_t Scale = 1; +  SCEVHandle NewOffset = SE->getIntegerSCEV(0, UIntPtrTy); +  std::stable_sort(StrideOrder.begin(), StrideOrder.end(), StrideCompare()); -  // Check stride constant and the comparision constant signs to detect -  // overflow. -  if ((CmpVal & SignBit) != (CmpSSInt & SignBit)) -    return Cond; +  if (ConstantInt *C = dyn_cast<ConstantInt>(Cond->getOperand(1))) { +    int64_t CmpVal = C->getValue().getSExtValue(); -  // Look for a suitable stride / iv as replacement. -  std::stable_sort(StrideOrder.begin(), StrideOrder.end(), StrideCompare()); -  for (unsigned i = 0, e = StrideOrder.size(); i != e; ++i) { -    std::map<SCEVHandle, IVUsersOfOneStride>::iterator SI =  -      IVUsesByStride.find(StrideOrder[i]); -    if (!isa<SCEVConstant>(SI->first)) -      continue; -    int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue(); -    if (abs(SSInt) <= abs(CmpSSInt) || (SSInt % CmpSSInt) != 0) -      continue; +    // Check stride constant and the comparision constant signs to detect +    // overflow. +    if ((CmpVal & SignBit) != (CmpSSInt & SignBit)) +      return Cond; -    Scale = SSInt / CmpSSInt; -    int64_t NewCmpVal = CmpVal * Scale; -    APInt Mul = APInt(BitWidth, NewCmpVal); -    // Check for overflow. -    if (Mul.getSExtValue() != NewCmpVal) -      continue; +    // Look for a suitable stride / iv as replacement. +    for (unsigned i = 0, e = StrideOrder.size(); i != e; ++i) { +      std::map<SCEVHandle, IVUsersOfOneStride>::iterator SI =  +        IVUsesByStride.find(StrideOrder[i]); +      if (!isa<SCEVConstant>(SI->first)) +        continue; +      int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue(); +      if (abs(SSInt) <= abs(CmpSSInt) || (SSInt % CmpSSInt) != 0) +        continue; -    // Watch out for overflow. -    if (ICmpInst::isSignedPredicate(Predicate) && -        (CmpVal & SignBit) != (NewCmpVal & SignBit)) -      continue; +      Scale = SSInt / CmpSSInt; +      int64_t NewCmpVal = CmpVal * Scale; +      APInt Mul = APInt(BitWidth, NewCmpVal); +      // Check for overflow. +      if (Mul.getSExtValue() != NewCmpVal) +        continue; -    if (NewCmpVal == CmpVal) -      continue; -    // Pick the best iv to use trying to avoid a cast. -    NewCmpLHS = NULL; -    for (std::vector<IVStrideUse>::iterator UI = SI->second.Users.begin(), -           E = SI->second.Users.end(); UI != E; ++UI) { -      NewCmpLHS = UI->OperandValToReplace; -      if (NewCmpLHS->getType() == CmpTy) -        break; -    } -    if (!NewCmpLHS) -      continue; +      // Watch out for overflow. +      if (ICmpInst::isSignedPredicate(Predicate) && +          (CmpVal & SignBit) != (NewCmpVal & SignBit)) +        continue; -    NewCmpTy = NewCmpLHS->getType(); -    NewTyBits = isa<PointerType>(NewCmpTy) -      ? UIntPtrTy->getPrimitiveSizeInBits() -      : NewCmpTy->getPrimitiveSizeInBits(); -    if (RequiresTypeConversion(NewCmpTy, CmpTy)) { -      // Check if it is possible to rewrite it using -      // an iv / stride of a smaller integer type. -      bool TruncOk = false; -      if (NewCmpTy->isInteger()) { -        unsigned Bits = NewTyBits; -        if (ICmpInst::isSignedPredicate(Predicate)) -          --Bits; -        uint64_t Mask = (1ULL << Bits) - 1; -        if (((uint64_t)NewCmpVal & Mask) == (uint64_t)NewCmpVal) -          TruncOk = true; +      if (NewCmpVal == CmpVal) +        continue; +      // Pick the best iv to use trying to avoid a cast. +      NewCmpLHS = NULL; +      for (std::vector<IVStrideUse>::iterator UI = SI->second.Users.begin(), +             E = SI->second.Users.end(); UI != E; ++UI) { +        NewCmpLHS = UI->OperandValToReplace; +        if (NewCmpLHS->getType() == CmpTy) +          break;        } -      if (!TruncOk) +      if (!NewCmpLHS)          continue; -    } -    // Don't rewrite if use offset is non-constant and the new type is -    // of a different type. -    // FIXME: too conservative? -    if (NewTyBits != TyBits && !isa<SCEVConstant>(CondUse->Offset)) -      continue; +      NewCmpTy = NewCmpLHS->getType(); +      NewTyBits = isa<PointerType>(NewCmpTy) +        ? UIntPtrTy->getPrimitiveSizeInBits() +        : NewCmpTy->getPrimitiveSizeInBits(); +      if (RequiresTypeConversion(NewCmpTy, CmpTy)) { +        // Check if it is possible to rewrite it using +        // an iv / stride of a smaller integer type. +        bool TruncOk = false; +        if (NewCmpTy->isInteger()) { +          unsigned Bits = NewTyBits; +          if (ICmpInst::isSignedPredicate(Predicate)) +            --Bits; +          uint64_t Mask = (1ULL << Bits) - 1; +          if (((uint64_t)NewCmpVal & Mask) == (uint64_t)NewCmpVal) +            TruncOk = true; +        } +        if (!TruncOk) +          continue; +      } -    bool AllUsesAreAddresses = true; -    bool AllUsesAreOutsideLoop = true; -    std::vector<BasedUser> UsersToProcess; -    SCEVHandle CommonExprs = CollectIVUsers(SI->first, SI->second, L, -                                            AllUsesAreAddresses, -                                            AllUsesAreOutsideLoop, -                                            UsersToProcess); -    // Avoid rewriting the compare instruction with an iv of new stride -    // if it's likely the new stride uses will be rewritten using the -    // stride of the compare instruction. -    if (AllUsesAreAddresses && -        ValidStride(!CommonExprs->isZero(), Scale, UsersToProcess)) -      continue; +      // Don't rewrite if use offset is non-constant and the new type is +      // of a different type. +      // FIXME: too conservative? +      if (NewTyBits != TyBits && !isa<SCEVConstant>(CondUse->Offset)) +        continue; -    // If scale is negative, use swapped predicate unless it's testing -    // for equality. -    if (Scale < 0 && !Cond->isEquality()) -      Predicate = ICmpInst::getSwappedPredicate(Predicate); - -    NewStride = &StrideOrder[i]; -    if (!isa<PointerType>(NewCmpTy)) -      NewCmpRHS = ConstantInt::get(NewCmpTy, NewCmpVal); -    else { -      NewCmpRHS = ConstantInt::get(UIntPtrTy, NewCmpVal); -      NewCmpRHS = SCEVExpander::InsertCastOfTo(Instruction::IntToPtr, -                                               NewCmpRHS, NewCmpTy); +      bool AllUsesAreAddresses = true; +      bool AllUsesAreOutsideLoop = true; +      std::vector<BasedUser> UsersToProcess; +      SCEVHandle CommonExprs = CollectIVUsers(SI->first, SI->second, L, +                                              AllUsesAreAddresses, +                                              AllUsesAreOutsideLoop, +                                              UsersToProcess); +      // Avoid rewriting the compare instruction with an iv of new stride +      // if it's likely the new stride uses will be rewritten using the +      // stride of the compare instruction. +      if (AllUsesAreAddresses && +          ValidStride(!CommonExprs->isZero(), Scale, UsersToProcess)) +        continue; + +      // If scale is negative, use swapped predicate unless it's testing +      // for equality. +      if (Scale < 0 && !Cond->isEquality()) +        Predicate = ICmpInst::getSwappedPredicate(Predicate); + +      NewStride = &StrideOrder[i]; +      if (!isa<PointerType>(NewCmpTy)) +        NewCmpRHS = ConstantInt::get(NewCmpTy, NewCmpVal); +      else { +        NewCmpRHS = ConstantInt::get(UIntPtrTy, NewCmpVal); +        NewCmpRHS = SCEVExpander::InsertCastOfTo(Instruction::IntToPtr, +                                                 NewCmpRHS, NewCmpTy); +      } +      NewOffset = TyBits == NewTyBits +        ? SE->getMulExpr(CondUse->Offset, +                         SE->getConstant(ConstantInt::get(CmpTy, Scale))) +        : SE->getConstant(ConstantInt::get(NewCmpTy, +          cast<SCEVConstant>(CondUse->Offset)->getValue()->getSExtValue()*Scale)); +      break;      } -    break;    }    // Forgo this transformation if it the increment happens to be @@ -2275,11 +2282,6 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,      OldCond->eraseFromParent();      IVUsesByStride[*CondStride].Users.pop_back(); -    SCEVHandle NewOffset = TyBits == NewTyBits -      ? SE->getMulExpr(CondUse->Offset, -                       SE->getConstant(ConstantInt::get(CmpTy, Scale))) -      : SE->getConstant(ConstantInt::get(NewCmpTy, -        cast<SCEVConstant>(CondUse->Offset)->getValue()->getSExtValue()*Scale));      IVUsesByStride[*NewStride].addUser(NewOffset, Cond, NewCmpLHS);      CondUse = &IVUsesByStride[*NewStride].Users.back();      CondStride = NewStride; | 

