diff options
author | Sanjay Patel <spatel@rotateright.com> | 2016-09-12 15:24:31 +0000 |
---|---|---|
committer | Sanjay Patel <spatel@rotateright.com> | 2016-09-12 15:24:31 +0000 |
commit | 3151dec7f1d61816101ff8bb29c35610eb26684a (patch) | |
tree | 033b0ac52e18be3c5e753ac846ed14cff4a3a86c /llvm/lib/Transforms | |
parent | fb0d9d9c13be4e54489f0ebbe18c9626c5b63e9c (diff) | |
download | bcm5719-llvm-3151dec7f1d61816101ff8bb29c35610eb26684a.tar.gz bcm5719-llvm-3151dec7f1d61816101ff8bb29c35610eb26684a.zip |
[InstCombine] add helper function for foldICmpUsingKnownBits; NFCI
llvm-svn: 281217
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 537 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 1 |
2 files changed, 279 insertions, 259 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index a28809d4bb0..977d3cbad5c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3135,6 +3135,274 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, return false; } +/// Try to fold the comparison based on range information we can get by checking +/// whether bits are known to be zero or one in the inputs. +Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = Op0->getType(); + + // Get scalar or pointer size. + unsigned BitWidth = Ty->isIntOrIntVectorTy() + ? Ty->getScalarSizeInBits() + : DL.getTypeSizeInBits(Ty->getScalarType()); + + if (!BitWidth) + return nullptr; + + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. + bool IsSignBit = false; + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + bool UnusedBit; + IsSignBit = isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit); + } + + APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); + APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + + if (SimplifyDemandedBits(I.getOperandUse(0), + DemandedBitsLHSMask(I, BitWidth, IsSignBit), + Op0KnownZero, Op0KnownOne, 0)) + return &I; + + if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth), + Op1KnownZero, Op1KnownOne, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (I.isSigned()) { + ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } else { + ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } + + // If Min and Max are known to be the same, then SimplifyDemandedBits + // figured out that the LHS is a constant. Just constant fold this now so + // that code below can assume that Min != Max. + if (!isa<Constant>(Op0) && Op0Min == Op0Max) + return new ICmpInst(I.getPredicate(), + ConstantInt::get(Op0->getType(), Op0Min), Op1); + if (!isa<Constant>(Op1) && Op1Min == Op1Max) + return new ICmpInst(I.getPredicate(), Op0, + ConstantInt::get(Op1->getType(), Op1Min)); + + // Based on the range information we know about the LHS, see if we can + // simplify this comparison. For example, (x&4) < 8 is always true. + switch (I.getPredicate()) { + default: + llvm_unreachable("Unknown icmp opcode!"); + case ICmpInst::ICMP_EQ: { + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + // If all bits are known zero except for one, then we know at most one + // bit is set. If the comparison is against zero, then this is a check + // to see if *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = nullptr; + ConstantInt *LHSC = nullptr; + if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || + LHSC->getValue() != Op0KnownZeroInverted) + LHS = Op0; + + // If the LHS is 1 << x, and we know the result is a power of 2 like 8, + // then turn "((1 << x)&8) == 0" into "x != 3". + // or turn "((1 << x)&7) == 0" into "x > 2". + Value *X = nullptr; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + APInt ValToCheck = Op0KnownZeroInverted; + if (ValToCheck.isPowerOf2()) { + unsigned CmpVal = ValToCheck.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(X->getType(), CmpVal)); + } else if ((++ValToCheck).isPowerOf2()) { + unsigned CmpVal = ValToCheck.countTrailingZeros() - 1; + return new ICmpInst(ICmpInst::ICMP_UGT, X, + ConstantInt::get(X->getType(), CmpVal)); + } + } + + // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, + // then turn "((8 >>u x)&1) == 0" into "x != 3". + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) + return new ICmpInst( + ICmpInst::ICMP_NE, X, + ConstantInt::get(X->getType(), CI->countTrailingZeros())); + } + break; + } + case ICmpInst::ICMP_NE: { + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + // If all bits are known zero except for one, then we know at most one + // bit is set. If the comparison is against zero, then this is a check + // to see if *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = nullptr; + ConstantInt *LHSC = nullptr; + if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || + LHSC->getValue() != Op0KnownZeroInverted) + LHS = Op0; + + // If the LHS is 1 << x, and we know the result is a power of 2 like 8, + // then turn "((1 << x)&8) != 0" into "x == 3". + // or turn "((1 << x)&7) != 0" into "x < 3". + Value *X = nullptr; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + APInt ValToCheck = Op0KnownZeroInverted; + if (ValToCheck.isPowerOf2()) { + unsigned CmpVal = ValToCheck.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(X->getType(), CmpVal)); + } else if ((++ValToCheck).isPowerOf2()) { + unsigned CmpVal = ValToCheck.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_ULT, X, + ConstantInt::get(X->getType(), CmpVal)); + } + } + + // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, + // then turn "((8 >>u x)&1) != 0" into "x == 3". + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) + return new ICmpInst( + ICmpInst::ICMP_EQ, X, + ConstantInt::get(X->getType(), CI->countTrailingZeros())); + } + break; + } + case ICmpInst::ICMP_ULT: { + if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (Op1Max == Op0Min + 1) { + Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); + } + // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear + if (CmpC->isMinSignedValue()) { + Constant *AllOnes = Constant::getAllOnesValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes); + } + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + + // (x >u 2147483647) -> (x <s 0) -> true if sign bit set + if (CmpC->isMaxSignedValue()) + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, + Constant::getNullValue(Op0->getType())); + } + break; + } + case ICmpInst::ICMP_SLT: + if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() - 1)); + } + break; + case ICmpInst::ICMP_SGT: + if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() + 1)); + } + break; + case ICmpInst::ICMP_SGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); + if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_SLE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); + if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_UGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); + if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_ULE: + assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); + if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + + // Turn a signed comparison into an unsigned one if both operands are known to + // have the same sign. + if (I.isSigned() && + ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || + (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) + return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); + + return nullptr; +} + /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. @@ -3276,14 +3544,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) return NewICmp; - unsigned BitWidth = 0; - if (Ty->isIntOrIntVectorTy()) - BitWidth = Ty->getScalarSizeInBits(); - else // Get pointer size. - BitWidth = DL.getTypeSizeInBits(Ty->getScalarType()); - - bool isSignBit = false; - // See if we are doing a comparison with a constant. if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { Value *A = nullptr, *B = nullptr; @@ -3365,11 +3625,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } - // If this comparison is a normal comparison, it demands all - // bits, if it is a sign bit comparison, it only demands the sign bit. - bool UnusedBit; - isSignBit = isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit); - // Canonicalize icmp instructions based on dominating conditions. BasicBlock *Parent = I.getParent(); BasicBlock *Dom = Parent->getSinglePredecessor(); @@ -3393,12 +3648,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return replaceInstUsesWith(I, Builder->getFalse()); if (Difference.isEmptySet()) return replaceInstUsesWith(I, Builder->getTrue()); + + // If this is a normal comparison, it demands all bits. If it is a sign + // bit comparison, it only demands the sign bit. + bool UnusedBit; + bool IsSignBit = + isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit); + // Canonicalizing a sign bit comparison that gets used in a branch, // pessimizes codegen by generating branch on zero instruction instead // of a test and branch. So we avoid canonicalizing in such situations // because test and branch instruction has better branch displacement // than compare and branch instruction. - if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) { + if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) { if (auto *AI = Intersection.getSingleElement()) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI)); if (auto *AD = Difference.getSingleElement()) @@ -3407,251 +3669,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } - // See if we can fold the comparison based on range information we can get - // by checking whether bits are known to be zero or one in the input. - if (BitWidth != 0) { - APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); - APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); - - if (SimplifyDemandedBits(I.getOperandUse(0), - DemandedBitsLHSMask(I, BitWidth, isSignBit), - Op0KnownZero, Op0KnownOne, 0)) - return &I; - if (SimplifyDemandedBits(I.getOperandUse(1), - APInt::getAllOnesValue(BitWidth), Op1KnownZero, - Op1KnownOne, 0)) - return &I; - - // Given the known and unknown bits, compute a range that the LHS could be - // in. Compute the Min, Max and RHS values based on the known bits. For the - // EQ and NE we use unsigned values. - APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); - APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); - if (I.isSigned()) { - ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } else { - ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } - - // If Min and Max are known to be the same, then SimplifyDemandedBits - // figured out that the LHS is a constant. Just constant fold this now so - // that code below can assume that Min != Max. - if (!isa<Constant>(Op0) && Op0Min == Op0Max) - return new ICmpInst(I.getPredicate(), - ConstantInt::get(Op0->getType(), Op0Min), Op1); - if (!isa<Constant>(Op1) && Op1Min == Op1Max) - return new ICmpInst(I.getPredicate(), Op0, - ConstantInt::get(Op1->getType(), Op1Min)); - - // Based on the range information we know about the LHS, see if we can - // simplify this comparison. For example, (x&4) < 8 is always true. - switch (I.getPredicate()) { - default: llvm_unreachable("Unknown icmp opcode!"); - case ICmpInst::ICMP_EQ: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) == 0" into "x != 3". - // or turn "((1 << x)&7) == 0" into "x > 2". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros() - 1; - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) == 0" into "x != 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_NE: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) != 0" into "x == 3". - // or turn "((1 << x)&7) != 0" into "x < 3". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_ULT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) != 0" into "x == 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_ULT: { - if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - // A <u C -> A == C-1 if min(A)+1 == C - if (Op1Max == Op0Min + 1) { - Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); - } - // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear - if (CmpC->isMinSignedValue()) { - Constant *AllOnes = Constant::getAllOnesValue(Op0->getType()); - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes); - } - } - break; - } - case ICmpInst::ICMP_UGT: { - if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - // A >u C -> A == C+1 if max(a)-1 == C - if (*CmpC == Op0Max - 1) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - ConstantInt::get(Op1->getType(), *CmpC + 1)); - - // (x >u 2147483647) -> (x <s 0) -> true if sign bit set - if (CmpC->isMaxSignedValue()) - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - Constant::getNullValue(Op0->getType())); - } - break; - } - case ICmpInst::ICMP_SLT: - if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - } - break; - case ICmpInst::ICMP_SGT: - if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - } - break; - case ICmpInst::ICMP_SGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); - if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_SLE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); - if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_UGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); - if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_ULE: - assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); - if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } - - // Turn a signed comparison into an unsigned one if both operands - // are known to have the same sign. - if (I.isSigned() && - ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || - (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) - return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); - } + if (Instruction *Res = foldICmpUsingKnownBits(I)) + return Res; // Test if the ICmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index bf0f4b2efa1..00acab2cbec 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -556,6 +556,7 @@ private: ICmpInst::Predicate Pred); Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); + Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, |