diff options
-rw-r--r-- | llvm/lib/Analysis/InstructionSimplify.cpp | 276 |
1 files changed, 143 insertions, 133 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index f79cfadcc66..5b7805869ea 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -2152,6 +2152,147 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, return nullptr; } +static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, + Value *RHS) { + // FIXME: Use m_APInt here and below to allow splat vector folds. + ConstantInt *CI = dyn_cast<ConstantInt>(RHS); + if (!CI) + return nullptr; + + // Rule out tautological comparisons (eg., ult 0 or uge 0). + ConstantRange RHS_CR = ICmpInst::makeConstantRange(Pred, CI->getValue()); + if (RHS_CR.isEmptySet()) + return ConstantInt::getFalse(CI->getContext()); + if (RHS_CR.isFullSet()) + return ConstantInt::getTrue(CI->getContext()); + + // Many binary operators with constant RHS have easy to compute constant + // range. Use them to check whether the comparison is a tautology. + unsigned Width = CI->getBitWidth(); + APInt Lower = APInt(Width, 0); + APInt Upper = APInt(Width, 0); + ConstantInt *CI2; + if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) { + // 'urem x, CI2' produces [0, CI2). + Upper = CI2->getValue(); + } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) { + // 'srem x, CI2' produces (-|CI2|, |CI2|). + Upper = CI2->getValue().abs(); + Lower = (-Upper) + 1; + } else if (match(LHS, m_UDiv(m_ConstantInt(CI2), m_Value()))) { + // 'udiv CI2, x' produces [0, CI2]. + Upper = CI2->getValue() + 1; + } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) { + // 'udiv x, CI2' produces [0, UINT_MAX / CI2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (!CI2->isZero()) + Upper = NegOne.udiv(CI2->getValue()) + 1; + } else if (match(LHS, m_SDiv(m_ConstantInt(CI2), m_Value()))) { + if (CI2->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = CI2->getValue(); + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv CI2, x' produces [-|CI2|, |CI2|]. + Upper = CI2->getValue().abs() + 1; + Lower = (-Upper) + 1; + } + } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + const APInt &Val = CI2->getValue(); + if (Val.isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where CI2 != -1 and CI2 != 0 and CI2 != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (Val.countLeadingZeros() < Width - 1) { + // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2] + // where CI2 != -1 and CI2 != 0 and CI2 != 1 + Lower = IntMin.sdiv(Val); + Upper = IntMax.sdiv(Val); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); + } + } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) { + // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)] + Lower = CI2->getValue(); + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) { + if (CI2->isNegative()) { + // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2] + unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1; + Lower = CI2->getValue().shl(ShiftAmount); + Upper = CI2->getValue() + 1; + } else { + // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1] + unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1; + Lower = CI2->getValue(); + Upper = CI2->getValue().shl(ShiftAmount) + 1; + } + } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { + // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (CI2->getValue().ult(Width)) + Upper = NegOne.lshr(CI2->getValue()) + 1; + } else if (match(LHS, m_LShr(m_ConstantInt(CI2), m_Value()))) { + // 'lshr CI2, x' produces [CI2 >> (Width-1), CI2]. + unsigned ShiftAmount = Width - 1; + if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) + ShiftAmount = CI2->getValue().countTrailingZeros(); + Lower = CI2->getValue().lshr(ShiftAmount); + Upper = CI2->getValue() + 1; + } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) { + // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2]. + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (CI2->getValue().ult(Width)) { + Lower = IntMin.ashr(CI2->getValue()); + Upper = IntMax.ashr(CI2->getValue()) + 1; + } + } else if (match(LHS, m_AShr(m_ConstantInt(CI2), m_Value()))) { + unsigned ShiftAmount = Width - 1; + if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) + ShiftAmount = CI2->getValue().countTrailingZeros(); + if (CI2->isNegative()) { + // 'ashr CI2, x' produces [CI2, CI2 >> (Width-1)] + Lower = CI2->getValue(); + Upper = CI2->getValue().ashr(ShiftAmount) + 1; + } else { + // 'ashr CI2, x' produces [CI2 >> (Width-1), CI2] + Lower = CI2->getValue().ashr(ShiftAmount); + Upper = CI2->getValue() + 1; + } + } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) { + // 'or x, CI2' produces [CI2, UINT_MAX]. + Lower = CI2->getValue(); + } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) { + // 'and x, CI2' produces [0, CI2]. + Upper = CI2->getValue() + 1; + } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) { + // 'add nuw x, CI2' produces [CI2, UINT_MAX]. + Lower = CI2->getValue(); + } + + ConstantRange LHS_CR = + Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); + + if (auto *I = dyn_cast<Instruction>(LHS)) + if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) + LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); + + if (!LHS_CR.isFullSet()) { + if (RHS_CR.contains(LHS_CR)) + return ConstantInt::getTrue(RHS->getContext()); + if (RHS_CR.inverse().contains(LHS_CR)) + return ConstantInt::getFalse(RHS->getContext()); + } + + return nullptr; +} + /// Given operands for an ICmpInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -2290,139 +2431,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // See if we are doing a comparison with a constant integer. - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Rule out tautological comparisons (eg., ult 0 or uge 0). - ConstantRange RHS_CR = ICmpInst::makeConstantRange(Pred, CI->getValue()); - if (RHS_CR.isEmptySet()) - return ConstantInt::getFalse(CI->getContext()); - if (RHS_CR.isFullSet()) - return ConstantInt::getTrue(CI->getContext()); - - // Many binary operators with constant RHS have easy to compute constant - // range. Use them to check whether the comparison is a tautology. - unsigned Width = CI->getBitWidth(); - APInt Lower = APInt(Width, 0); - APInt Upper = APInt(Width, 0); - ConstantInt *CI2; - if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) { - // 'urem x, CI2' produces [0, CI2). - Upper = CI2->getValue(); - } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) { - // 'srem x, CI2' produces (-|CI2|, |CI2|). - Upper = CI2->getValue().abs(); - Lower = (-Upper) + 1; - } else if (match(LHS, m_UDiv(m_ConstantInt(CI2), m_Value()))) { - // 'udiv CI2, x' produces [0, CI2]. - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) { - // 'udiv x, CI2' produces [0, UINT_MAX / CI2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (!CI2->isZero()) - Upper = NegOne.udiv(CI2->getValue()) + 1; - } else if (match(LHS, m_SDiv(m_ConstantInt(CI2), m_Value()))) { - if (CI2->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = CI2->getValue(); - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv CI2, x' produces [-|CI2|, |CI2|]. - Upper = CI2->getValue().abs() + 1; - Lower = (-Upper) + 1; - } - } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - const APInt &Val = CI2->getValue(); - if (Val.isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where CI2 != -1 and CI2 != 0 and CI2 != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (Val.countLeadingZeros() < Width - 1) { - // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2] - // where CI2 != -1 and CI2 != 0 and CI2 != 1 - Lower = IntMin.sdiv(Val); - Upper = IntMax.sdiv(Val); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) { - // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)] - Lower = CI2->getValue(); - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) { - if (CI2->isNegative()) { - // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2] - unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1; - Lower = CI2->getValue().shl(ShiftAmount); - Upper = CI2->getValue() + 1; - } else { - // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1] - unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1; - Lower = CI2->getValue(); - Upper = CI2->getValue().shl(ShiftAmount) + 1; - } - } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { - // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (CI2->getValue().ult(Width)) - Upper = NegOne.lshr(CI2->getValue()) + 1; - } else if (match(LHS, m_LShr(m_ConstantInt(CI2), m_Value()))) { - // 'lshr CI2, x' produces [CI2 >> (Width-1), CI2]. - unsigned ShiftAmount = Width - 1; - if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = CI2->getValue().countTrailingZeros(); - Lower = CI2->getValue().lshr(ShiftAmount); - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) { - // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2]. - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (CI2->getValue().ult(Width)) { - Lower = IntMin.ashr(CI2->getValue()); - Upper = IntMax.ashr(CI2->getValue()) + 1; - } - } else if (match(LHS, m_AShr(m_ConstantInt(CI2), m_Value()))) { - unsigned ShiftAmount = Width - 1; - if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = CI2->getValue().countTrailingZeros(); - if (CI2->isNegative()) { - // 'ashr CI2, x' produces [CI2, CI2 >> (Width-1)] - Lower = CI2->getValue(); - Upper = CI2->getValue().ashr(ShiftAmount) + 1; - } else { - // 'ashr CI2, x' produces [CI2 >> (Width-1), CI2] - Lower = CI2->getValue().ashr(ShiftAmount); - Upper = CI2->getValue() + 1; - } - } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) { - // 'or x, CI2' produces [CI2, UINT_MAX]. - Lower = CI2->getValue(); - } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) { - // 'and x, CI2' produces [0, CI2]. - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) { - // 'add nuw x, CI2' produces [CI2, UINT_MAX]. - Lower = CI2->getValue(); - } - - ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) - : ConstantRange(Width, true); - - if (auto *I = dyn_cast<Instruction>(LHS)) - if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) - LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); - - if (!LHS_CR.isFullSet()) { - if (RHS_CR.contains(LHS_CR)) - return ConstantInt::getTrue(RHS->getContext()); - if (RHS_CR.inverse().contains(LHS_CR)) - return ConstantInt::getFalse(RHS->getContext()); - } - } + if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS)) + return V; // If both operands have range metadata, use the metadata // to simplify the comparison. |