diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index dbc2175196a..743e20b294f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1612,6 +1612,61 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } + // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) + // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) + // The original condition actually refers to the following two ranges: + // [MAX_UINT-C1+1, MAX_UINT-C1+1+C3] and [MAX_UINT-C2+1, MAX_UINT-C2+1+C3] + // We can fold these two ranges if: + // 1) C1 and C2 is unsigned greater than C3. + // 2) The two ranges are separated. + // 3) C1 ^ C2 is one-bit mask. + // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. + // This implies all values in the two ranges differ by exactly one bit. + + if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) && + LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() && + RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() && + LHSCst->getValue() == (RHSCst->getValue())) { + + Value *LAdd = LHS->getOperand(0); + Value *RAdd = RHS->getOperand(0); + + Value *LAddOpnd, *RAddOpnd; + ConstantInt *LAddCst, *RAddCst; + if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) && + match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) && + LAddCst->getValue().ugt(LHSCst->getValue()) && + RAddCst->getValue().ugt(LHSCst->getValue())) { + + APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue(); + if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) { + ConstantInt *MaxAddCst = nullptr; + if (LAddCst->getValue().ult(RAddCst->getValue())) + MaxAddCst = RAddCst; + else + MaxAddCst = LAddCst; + + APInt RRangeLow = -RAddCst->getValue(); + APInt RRangeHigh = RRangeLow + LHSCst->getValue(); + APInt LRangeLow = -LAddCst->getValue(); + APInt LRangeHigh = LRangeLow + LHSCst->getValue(); + APInt LowRangeDiff = RRangeLow ^ LRangeLow; + APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; + APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow + : RRangeLow - LRangeLow; + + if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && + RangeDiff.ugt(LHSCst->getValue())) { + Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst); + + Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst); + Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst); + return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst)); + } + } + } + } + // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) if (PredicatesFoldable(LHSCC, RHSCC)) { if (LHS->getOperand(0) == RHS->getOperand(1) && |

