diff options
-rw-r--r-- | llvm/lib/Analysis/InstructionSimplify.cpp | 212 |
1 files changed, 115 insertions, 97 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 0cb2c78afb4..9eefc9980fe 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -3371,6 +3371,118 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return nullptr; } +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison where one operand of the compare is a constant. +static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, + const APInt *Y, bool TrueWhenUnset) { + const APInt *C; + + // (X & Y) == 0 ? X & ~Y : X --> X + // (X & Y) != 0 ? X & ~Y : X --> X & ~Y + if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + // (X & Y) == 0 ? X : X & ~Y --> X & ~Y + // (X & Y) != 0 ? X : X & ~Y --> X + if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + if (Y->isPowerOf2()) { + // (X & Y) == 0 ? X | Y : X --> X | Y + // (X & Y) != 0 ? X | Y : X --> X + if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + + // (X & Y) == 0 ? X : X | Y --> X + // (X & Y) != 0 ? X : X | Y --> X | Y + if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + } + + return nullptr; +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison. +static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, + Value *FalseVal, const Query &Q, + unsigned MaxRecurse) { + ICmpInst::Predicate Pred; + Value *CmpLHS, *CmpRHS; + if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) + return nullptr; + + unsigned BitWidth = Q.DL.getTypeSizeInBits(TrueVal->getType()); + APInt MinSignedValue = APInt::getSignBit(BitWidth); + if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) { + Value *X; + const APInt *Y; + if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y)))) + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, + Pred == ICmpInst::ICMP_EQ)) + return V; + } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, CmpLHS, + &MinSignedValue, false)) + return V; + } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, CmpLHS, + &MinSignedValue, true)) + return V; + } + + if (CondVal->hasOneUse()) { + const APInt *C; + if (match(CmpRHS, m_APInt(C))) { + // X < MIN ? T : F --> F + if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) + return FalseVal; + // X < MIN ? T : F --> F + if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) + return FalseVal; + // X > MAX ? T : F --> F + if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) + return FalseVal; + // X > MAX ? T : F --> F + if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) + return FalseVal; + } + } + + // If we have an equality comparison, then we know the value in one of the + // arms of the select. See if substituting this value into the arm and + // simplifying the result yields the same value as the other arm. + if (Pred == ICmpInst::ICMP_EQ) { + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + TrueVal) + return FalseVal; + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + FalseVal) + return FalseVal; + } else if (Pred == ICmpInst::ICMP_NE) { + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + FalseVal) + return TrueVal; + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + TrueVal) + return TrueVal; + } + + return nullptr; +} + /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, @@ -3399,103 +3511,9 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, if (isa<UndefValue>(FalseVal)) // select C, X, undef -> X return TrueVal; - if (const auto *ICI = dyn_cast<ICmpInst>(CondVal)) { - unsigned BitWidth = Q.DL.getTypeSizeInBits(TrueVal->getType()); - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - APInt MinSignedValue = APInt::getSignBit(BitWidth); - Value *X; - const APInt *Y; - bool TrueWhenUnset; - bool IsBitTest = false; - if (ICmpInst::isEquality(Pred) && - match(CmpLHS, m_And(m_Value(X), m_APInt(Y))) && - match(CmpRHS, m_Zero())) { - IsBitTest = true; - TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; - } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { - X = CmpLHS; - Y = &MinSignedValue; - IsBitTest = true; - TrueWhenUnset = false; - } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { - X = CmpLHS; - Y = &MinSignedValue; - IsBitTest = true; - TrueWhenUnset = true; - } - if (IsBitTest) { - const APInt *C; - // (X & Y) == 0 ? X & ~Y : X --> X - // (X & Y) != 0 ? X & ~Y : X --> X & ~Y - if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && - *Y == ~*C) - return TrueWhenUnset ? FalseVal : TrueVal; - // (X & Y) == 0 ? X : X & ~Y --> X & ~Y - // (X & Y) != 0 ? X : X & ~Y --> X - if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && - *Y == ~*C) - return TrueWhenUnset ? FalseVal : TrueVal; - - if (Y->isPowerOf2()) { - // (X & Y) == 0 ? X | Y : X --> X | Y - // (X & Y) != 0 ? X | Y : X --> X - if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) - return TrueWhenUnset ? TrueVal : FalseVal; - // (X & Y) == 0 ? X : X | Y --> X - // (X & Y) != 0 ? X : X | Y --> X | Y - if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) - return TrueWhenUnset ? TrueVal : FalseVal; - } - } - if (ICI->hasOneUse()) { - const APInt *C; - if (match(CmpRHS, m_APInt(C))) { - // X < MIN ? T : F --> F - if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) - return FalseVal; - // X < MIN ? T : F --> F - if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) - return FalseVal; - // X > MAX ? T : F --> F - if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) - return FalseVal; - // X > MAX ? T : F --> F - if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) - return FalseVal; - } - } - - // If we have an equality comparison then we know the value in one of the - // arms of the select. See if substituting this value into the arm and - // simplifying the result yields the same value as the other arm. - if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return FalseVal; - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return FalseVal; - } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return TrueVal; - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return TrueVal; - } - } + if (Value *V = + simplifySelectWithICmpCond(CondVal, TrueVal, FalseVal, Q, MaxRecurse)) + return V; return nullptr; } |