diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 92 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 3 |
2 files changed, 46 insertions, 49 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 847b4a19995..95dab355e1f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1530,15 +1530,22 @@ Instruction *InstCombiner::foldICmpCstShlConst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } -/// Handle "icmp (instr, intcst)". -Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, - Instruction *LHSI, - ConstantInt *RHS) { - const APInt &RHSV = RHS->getValue(); +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C. +Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI) { + Instruction *LHSI; + const APInt *RHSV; + if (!match(ICI.getOperand(0), m_Instruction(LHSI)) || + !match(ICI.getOperand(1), m_APInt(RHSV))) + return nullptr; + + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast<ConstantInt>(ICI.getOperand(1)); + if (!RHS) + return nullptr; switch (LHSI->getOpcode()) { case Instruction::Trunc: - if (RHS->isOne() && RHSV.getBitWidth() > 1) { + if (RHS->isOne() && RHSV->getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (ICI.getPredicate() == ICmpInst::ICMP_SLT && @@ -1569,8 +1576,8 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { // If this is a comparison that tests the signbit (X < 0) or (x > -1), // fold the xor. - if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) || - (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) { + if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && *RHSV == 0) || + (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV->isAllOnesValue())) { Value *CompareVal = LHSI->getOperand(0); // If the sign bit of the XorCst is not set, there is no change to @@ -1603,7 +1610,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, ? ICI.getUnsignedPredicate() : ICI.getSignedPredicate(); return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ SignBit)); + Builder->getInt(*RHSV ^ SignBit)); } // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) @@ -1614,20 +1621,20 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, : ICI.getSignedPredicate(); Pred = ICI.getSwappedPredicate(Pred); return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ NotSignBit)); + Builder->getInt(*RHSV ^ NotSignBit)); } } // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) // iff -C is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_UGT && - XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2()) + XorCst->getValue() == ~(*RHSV) && (*RHSV + 1).isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); // (icmp ult (xor X, C), -C) -> (icmp uge X, C) // iff -C is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_ULT && - XorCst->getValue() == -RHSV && RHSV.isPowerOf2()) + XorCst->getValue() == -(*RHSV) && RHSV->isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); } break; @@ -1645,7 +1652,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // Extending a relational comparison when we're checking the sign // bit would not work. if (ICI.isEquality() || - (!AndCst->isNegative() && RHSV.isNonNegative())) { + (!AndCst->isNegative() && RHSV->isNonNegative())) { Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), ConstantExpr::getZExt(AndCst, Cast->getSrcTy())); @@ -1661,7 +1668,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, IntegerType *Ty = cast<IntegerType>(Cast->getSrcTy()); // Make sure we don't compare the upper bits, SimplifyDemandedBits // should fold the icmp to true/false in that case. - if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) { + if (ICI.isEquality() && RHSV->getActiveBits() <= Ty->getBitWidth()) { Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), ConstantExpr::getTrunc(AndCst, Ty)); @@ -1754,7 +1761,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is // preferable because it allows the C<<Y expression to be hoisted out // of a loop if Y is invariant and X is not. - if (Shift && Shift->hasOneUse() && RHSV == 0 && + if (Shift && Shift->hasOneUse() && *RHSV == 0 && ICI.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C << Y. @@ -1780,7 +1787,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // iff pred isn't signed { Value *X, *Y, *LShr; - if (!ICI.isSigned() && RHSV == 0) { + if (!ICI.isSigned() && *RHSV == 0) { if (match(LHSI->getOperand(1), m_One())) { Constant *One = cast<Constant>(LHSI->getOperand(1)); Value *Or = LHSI->getOperand(0); @@ -1821,7 +1828,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { unsigned NTZ = AndCst->getValue().countTrailingZeros(); if ((NTZ < AndCst->getBitWidth()) && - APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV)) + APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(*RHSV)) return new ICmpInst(ICmpInst::ICMP_NE, LHSI, Constant::getNullValue(RHS->getType())); } @@ -1843,7 +1850,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // X & -C == -C -> X > u ~C // X & -C != -C -> X <= u ~C // iff C is a power of 2 - if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2()) + if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-(*RHSV)).isPowerOf2()) return new ICmpInst( ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE, @@ -1915,13 +1922,13 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, } case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) - uint32_t TypeBits = RHSV.getBitWidth(); + uint32_t TypeBits = RHSV->getBitWidth(); ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1)); if (!ShAmt) { Value *X; // (1 << X) pred P2 -> X pred Log2(P2) if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { - bool RHSVIsPowerOf2 = RHSV.isPowerOf2(); + bool RHSVIsPowerOf2 = RHSV->isPowerOf2(); ICmpInst::Predicate Pred = ICI.getPredicate(); if (ICI.isUnsigned()) { if (!RHSVIsPowerOf2) { @@ -1934,7 +1941,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, else if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_UGT; } - unsigned RHSLog2 = RHSV.logBase2(); + unsigned RHSLog2 = RHSV->logBase2(); // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 // (1 << X) < 2147483648 -> X < 31 -> X != 31 @@ -1948,7 +1955,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, return new ICmpInst(Pred, X, ConstantInt::get(RHS->getType(), RHSLog2)); } else if (ICI.isSigned()) { - if (RHSV.isAllOnesValue()) { + if (RHSV->isAllOnesValue()) { // (1 << X) <= -1 -> X == 31 if (Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_EQ, X, @@ -1958,7 +1965,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, if (Pred == ICmpInst::ICMP_SGT) return new ICmpInst(ICmpInst::ICMP_NE, X, ConstantInt::get(RHS->getType(), TypeBits-1)); - } else if (!RHSV) { + } else if (!(*RHSV)) { // (1 << X) < 0 -> X == 31 // (1 << X) <= 0 -> X == 31 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) @@ -1974,7 +1981,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, } else if (ICI.isEquality()) { if (RHSVIsPowerOf2) return new ICmpInst( - Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); + Pred, X, ConstantInt::get(RHS->getType(), RHSV->logBase2())); } } break; @@ -2006,7 +2013,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // If the shift is NSW and we compare to 0, then it is just shifting out // sign bits, no need for an AND either. - if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0) + if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && *RHSV == 0) return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), ConstantExpr::getLShr(RHS, ShAmt)); @@ -2054,7 +2061,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // smaller constant, which will be target friendly. unsigned Amt = ShAmt->getLimitedValue(TypeBits-1); if (LHSI->hasOneUse() && - Amt != 0 && RHSV.countTrailingZeros() >= Amt) { + Amt != 0 && RHSV->countTrailingZeros() >= Amt) { Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt); Constant *NCI = ConstantExpr::getTrunc( ConstantExpr::getAShr(RHS, @@ -2079,7 +2086,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // Handle exact shr's. if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { - if (RHSV.isMinValue()) + if (RHSV->isMinValue()) return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); } break; @@ -2128,18 +2135,18 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // iff C1 & (C2-1) == C2-1 // C2 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1)) + RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == (*RHSV - 1)) return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateOr(LHSI->getOperand(1), RHSV - 1), + Builder->CreateOr(LHSI->getOperand(1), *RHSV - 1), LHSC); // C1-X >u C2 -> (X|C2) != C1 // iff C1 & C2 == C2 // C2+1 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV) + (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == *RHSV) return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC); + Builder->CreateOr(LHSI->getOperand(1), *RHSV), LHSC); break; } @@ -2150,7 +2157,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, if (!LHSC) break; const APInt &LHSV = LHSC->getValue(); - ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV) + ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), *RHSV) .subtract(LHSV); if (ICI.isSigned()) { @@ -2175,18 +2182,18 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI, // iff C1 & (C2-1) == 0 // C2 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0) + RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == 0) return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateAnd(LHSI->getOperand(0), -RHSV), + Builder->CreateAnd(LHSI->getOperand(0), -(*RHSV)), ConstantExpr::getNeg(LHSC)); // X-C1 >u C2 -> (X & ~C2) != C1 // iff C1 & C2 == 0 // C2+1 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0) + (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == 0) return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateAnd(LHSI->getOperand(0), ~RHSV), + Builder->CreateAnd(LHSI->getOperand(0), ~(*RHSV)), ConstantExpr::getNeg(LHSC)); } break; @@ -3627,17 +3634,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // See if we are doing a comparison between a constant and an instruction that // can be folded into the comparison. - // FIXME: Use m_APInt instead of dyn_cast<ConstantInt> to allow these - // transforms for vectors. - - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - // Since the RHS is a ConstantInt (CI), if the left hand side is an - // instruction, see if that instruction also has constants so that the - // instruction can be folded into the icmp - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - if (Instruction *Res = foldICmpWithConstant(I, LHSI, CI)) - return Res; - } + if (Instruction *Res = foldICmpWithConstant(I)) + return Res; if (Instruction *Res = foldICmpEqualityWithConstant(I)) return Res; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index c555ff8d129..88c52e2413d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -559,8 +559,7 @@ private: Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred); Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); - Instruction *foldICmpWithConstant(ICmpInst &ICI, Instruction *LHS, - ConstantInt *RHS); + Instruction *foldICmpWithConstant(ICmpInst &ICI); Instruction *foldICmpEqualityWithConstant(ICmpInst &ICI); Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI); |