diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 93 |
1 files changed, 93 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 5e71c5c4b7c..3a501132ebd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1044,6 +1044,90 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, return nullptr; } +/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// (icmp eq/ne A, Log2(const2/const1)) -> +/// (icmp eq/ne A, Log2(const2) - Log2(const1)). +Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + if (!AP1) { + if (!AP2) { + // Both Constants are 0. + return getConstant(true); + } + + if (cast<BinaryOperator>(Op)->isExact()) + return getConstant(false); + + if (AP2.isNegative()) { + // MSB is set, so a lshr with a large enough 'A' would be undefined. + return getConstant(false); + } + + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + } + + if (!AP2) { + // Shifting 0 by any value gives 0. + return getConstant(false); + } + + bool IsAShr = isa<AShrOperator>(Op); + if (AP1 == AP2) { + if (AP1.isAllOnesValue() && IsAShr) { + // Arithmatic shift of -1 is always -1. + return getConstant(true); + } + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + } + + if (IsAShr) { + if (AP1.isNegative() != AP2.isNegative()) { + // Arithmetic shift will never change the sign. + return getConstant(false); + } + // Both the constants are negative, take their positive to calculate + // log. + if (AP1.isNegative()) { + AP1 = -AP1; + AP2 = -AP2; + } + } + + if (AP1.ugt(AP2)) { + // Right-shifting will not increase the value. + return getConstant(false); + } + + // Get the distance between the highest bit that's set. + int Shift = AP2.logBase2() - AP1.logBase2(); + + // Use lshr here, since we've canonicalized to +ve numbers. + if (AP1 == AP2.lshr(Shift)) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + return getConstant(false); +} /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -2469,6 +2553,15 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Builder->getInt(CI->getValue()-1)); } + // (icmp eq/ne (ashr/lshr const2, A), const1) + if (I.isEquality()) { + ConstantInt *CI2; + if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || + match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { + return FoldICmpCstShrCst(I, Op0, A, CI, CI2); + } + } + // If this comparison is a normal comparison, it demands all // bits, if it is a sign bit comparison, it only demands the sign bit. bool UnusedBit; |