diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 76 |
1 files changed, 43 insertions, 33 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index dfa922b7480..ad2bd1841f1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1532,50 +1532,60 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, if (Instruction *I = foldICmpAndShift(Cmp, And, C1)) return I; + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)); + if (!RHS) + return nullptr; + // (icmp pred (and (or (lshr A, B), A), 1), 0) --> // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && *C1 == 0 && match(And->getOperand(1), m_One())) { - Constant *One = cast<Constant>(And->getOperand(1)); - Value *Or = And->getOperand(0); + { Value *A, *B, *LShr; - if (match(Or, m_Or(m_Value(LShr), m_Value(A))) && - match(LShr, m_LShr(m_Specific(A), m_Value(B)))) { - unsigned UsesRemoved = 0; - if (And->hasOneUse()) - ++UsesRemoved; - if (Or->hasOneUse()) - ++UsesRemoved; - if (LShr->hasOneUse()) - ++UsesRemoved; - - // Compute A & ((1 << B) | 1) - Value *NewOr = nullptr; - if (auto *C = dyn_cast<Constant>(B)) { - if (UsesRemoved >= 1) - NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder->CreateOr(Builder->CreateShl(One, B, LShr->getName(), + if (!Cmp.isSigned() && *C1 == 0) { + if (match(And->getOperand(1), m_One())) { + Constant *One = cast<Constant>(And->getOperand(1)); + Value *Or = And->getOperand(0); + if (match(Or, m_Or(m_Value(LShr), m_Value(A))) && + match(LShr, m_LShr(m_Specific(A), m_Value(B)))) { + unsigned UsesRemoved = 0; + if (And->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + Value *NewOr = nullptr; + // Compute A & ((1 << B) | 1) + if (auto *C = dyn_cast<Constant>(B)) { + if (UsesRemoved >= 1) + NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = + Builder->CreateOr(Builder->CreateShl(One, B, LShr->getName(), /*HasNUW=*/true), One, Or->getName()); - } - if (NewOr) { - Value *NewAnd = Builder->CreateAnd(A, NewOr, And->getName()); - Cmp.setOperand(0, NewAnd); - return &Cmp; + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(A, NewOr, And->getName()); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + } } } } - // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a - // result greater than C1. - unsigned NumTZ = C2->countTrailingZeros(); - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && - APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { - Constant *Zero = Constant::getNullValue(And->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + // Replace ((X & C2) > C1) with ((X & C2) != 0), if any bit set in (X & C2) + // will produce a result greater than C1. + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { + unsigned NTZ = C2->countTrailingZeros(); + if ((NTZ < C2->getBitWidth()) && + APInt::getOneBitSet(C2->getBitWidth(), NTZ).ugt(*C1)) + return new ICmpInst(ICmpInst::ICMP_NE, And, + Constant::getNullValue(RHS->getType())); } return nullptr; |