diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 63 |
1 files changed, 29 insertions, 34 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 0a9f2787efb..b41662f1c11 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2198,53 +2198,48 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, Instruction *Sub, /// Fold icmp (add X, Y), C. Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, Instruction *Add, const APInt *C) { - // FIXME: This check restricts all folds under here to scalar types. - ConstantInt *RHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)); - if (!RHS) - return nullptr; - - if (Cmp.isEquality()) + Value *Y = Add->getOperand(1); + const APInt *C2; + if (Cmp.isEquality() || !match(Y, m_APInt(C2))) return nullptr; - // Fold: icmp pred (add X, C2), C + // Fold icmp pred (add X, C2), C. Value *X = Add->getOperand(0); - ConstantInt *AddC = dyn_cast<ConstantInt>(Add->getOperand(1)); - if (!AddC) - return nullptr; - - const APInt &C2 = AddC->getValue(); - ConstantRange CR = Cmp.makeConstantRange(Cmp.getPredicate(), *C).subtract(C2); + Type *Ty = Add->getType(); + auto CR = Cmp.makeConstantRange(Cmp.getPredicate(), *C).subtract(*C2); const APInt &Upper = CR.getUpper(); const APInt &Lower = CR.getLower(); if (Cmp.isSigned()) { if (Lower.isSignBit()) - return new ICmpInst(ICmpInst::ICMP_SLT, X, Builder->getInt(Upper)); + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper)); if (Upper.isSignBit()) - return new ICmpInst(ICmpInst::ICMP_SGE, X, Builder->getInt(Lower)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower)); } else { if (Lower.isMinValue()) - return new ICmpInst(ICmpInst::ICMP_ULT, X, Builder->getInt(Upper)); + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper)); if (Upper.isMinValue()) - return new ICmpInst(ICmpInst::ICMP_UGE, X, Builder->getInt(Lower)); + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); } - if (Add->hasOneUse()) { - // X+C <u C2 -> (X & -C2) == C - // iff C & (C2-1) == 0 - // C2 is a power of 2 - if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && - (C2 & (*C - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), - ConstantExpr::getNeg(AddC)); - - // X+C >u C2 -> (X & ~C2) != C - // iff C & C2 == 0 - // C2+1 is a power of 2 - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && - (C2 & *C) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), - ConstantExpr::getNeg(AddC)); - } + if (!Add->hasOneUse()) + return nullptr; + + // X+C <u C2 -> (X & -C2) == C + // iff C & (C2-1) == 0 + // C2 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); + + // X+C >u C2 -> (X & ~C2) != C + // iff C & C2 == 0 + // C2+1 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && + (*C2 & *C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); + return nullptr; } |

