diff options
author | Ehsan Amiri <amehsan@ca.ibm.com> | 2016-12-15 12:25:13 +0000 |
---|---|---|
committer | Ehsan Amiri <amehsan@ca.ibm.com> | 2016-12-15 12:25:13 +0000 |
commit | 795b0671c5fd3c064f9502d388e5f40a196b9d56 (patch) | |
tree | 3919e2c326458cea442b821277c0034fe1764270 /llvm/lib/Transforms | |
parent | 3da2619b6f2827de20b2a727e06455fb4ee9c3fc (diff) | |
download | bcm5719-llvm-795b0671c5fd3c064f9502d388e5f40a196b9d56.tar.gz bcm5719-llvm-795b0671c5fd3c064f9502d388e5f40a196b9d56.zip |
[InstCombine] New opportunities for FoldAndOfICmps and FoldXorOfICmps
A number of new patterns for simplifying and/xor of icmp:
(icmp ne %x, 0) ^ (icmp ne %y, 0) => icmp ne %x, %y if both of the following hold:
1- (%x = and %a, %mask) and (%y = and %b, %mask), i.e. both values are masked with the same constant.
2- %mask is a power of 2.
(icmp eq %x, 0) & (icmp ne %y, 0) => icmp ult %x, %y if both of the following hold:
1- (%x = and %a, %mask1) and (%y = and %b, %mask2)
2- Let %t be the smallest power of 2 for which %mask1 & %t != 0 (i.e. the lowest set bit of %mask1). Then for every
%s that is a power of 2 with %s & %mask2 != 0 (i.e. every set bit of %mask2), we must have %s <= %t.
For example, if %mask1 = 24 and %mask2 = 16, then %t = 8 and %s = 16, which
violates condition (2) above, so this optimization cannot be applied.
llvm-svn: 289813
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 99 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 1 |
2 files changed, 98 insertions, 2 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index d4bd78bc805..e1e060b283e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -733,6 +733,44 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, return nullptr; } +namespace { + +struct BitGroupCheck { + // If the Cmp, checks the bits in the group are nonzero? + bool CheckIfSet {false}; + // The mask that identifies the bitgroup in question. + const APInt *Mask {nullptr}; +}; +} +/// For an ICMP where RHS is zero, we want to check if the ICMP is equivalent to +/// comparing a group of bits in an integer value against zero. +BitGroupCheck isAnyBitSet(Value *LHS, ICmpInst::Predicate CC) { + + BitGroupCheck BGC; + auto *Inst = dyn_cast<Instruction>(LHS); + + if (!Inst || Inst->getOpcode() != Instruction::And) + return BGC; + + // TODO Currently this does not work for vectors. + ConstantInt *Mask; + if (!match(LHS, m_And(m_Value(), m_ConstantInt(Mask)))) + return BGC; + // At this point we know that LHS of ICMP is "and" of a value with a constant. + // Also we know that the RHS is zero. That means we are checking if a certain + // group of bits in a given integer value are all zero or at least one of them + // is set to one. + if (CC == ICmpInst::ICMP_EQ) + BGC.CheckIfSet = false; + else if (CC == ICmpInst::ICMP_NE) + BGC.CheckIfSet = true; + else + return BGC; + + BGC.Mask = &Mask->getValue(); + return BGC; +} + /// Try to fold a signed range checked with lower bound 0 to an unsigned icmp. /// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n /// If \p Inverted is true then the check is for the inverted range, e.g. 
@@ -789,6 +827,32 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder->CreateICmp(NewPred, Input, RangeEnd); } +Value *InstCombiner::FoldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { + + Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); + // TODO The lines below does not work for vectors. ConstantInt is scalar. + auto *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); + auto *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); + if (!LHSCst || !RHSCst) + return nullptr; + ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + + // E.g. (icmp ne %x, 0) ^ (icmp ne %y, 0) => icmp ne %x, %y if the following + // conditions hold: + // 1- (%x = and %a, %mask) and (%y = and %b, %mask) + // 2- %mask is a power of 2. + if (RHSCst->isZero() && LHSCst == RHSCst) { + + BitGroupCheck BGC1 = isAnyBitSet(Val, LHSCC); + BitGroupCheck BGC2 = isAnyBitSet(Val2, RHSCC); + if (BGC1.Mask && BGC2.Mask && BGC1.CheckIfSet == BGC2.CheckIfSet && + *BGC1.Mask == *BGC2.Mask && BGC1.Mask->isPowerOf2()) { + return Builder->CreateICmp(ICmpInst::ICMP_NE, Val2, Val); + } + } + return nullptr; +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); @@ -871,6 +935,29 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } + // E.g. (icmp eq %x, 0) & (icmp ne %y, 0) => icmp ult %x, %y if the following + // conditions hold: + // 1- (%x = and %a, %mask1) and (%y = and %b, %mask2) + // 2- Let %t be the smallest power of 2 where %mask1 & %t != 0. Then for any + // %s that is a power of 2 and %s & %mask2 != 0, we must have %s <= %t. + // For example if %mask1 = 24 and %mask2 = 16, setting %s = 16 and %t = 8 + // violates condition (2) above. So this optimization cannot be applied. 
+ if (RHSCst->isZero() && LHSCst == RHSCst) { + BitGroupCheck BGC1 = isAnyBitSet(Val, LHSCC); + BitGroupCheck BGC2 = isAnyBitSet(Val2, RHSCC); + + if (BGC1.Mask && BGC2.Mask && (BGC1.CheckIfSet != BGC2.CheckIfSet)) { + if (!BGC1.CheckIfSet && + BGC1.Mask->countTrailingZeros() >= + BGC2.Mask->getBitWidth() - BGC2.Mask->countLeadingZeros() - 1) + return Builder->CreateICmp(ICmpInst::ICMP_ULT, Val, Val2); + else if (!BGC2.CheckIfSet && + BGC2.Mask->countTrailingZeros() >= + BGC1.Mask->getBitWidth() - BGC1.Mask->countLeadingZeros() - 1) + return Builder->CreateICmp(ICmpInst::ICMP_ULT, Val2, Val); + } + } + // From here on, we only handle: // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. if (Val != Val2) return nullptr; @@ -2704,9 +2791,16 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateNot(Builder->CreateAnd(A, B)); - // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) + if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) { + + // E.g. if we have xor (icmp eq %A, 0), (icmp eq %B, 0) + // and we know both A and B are either 8 (power of 2) or 0 + // we can simplify to (icmp ne A, B). 
+ if (Value *Res = FoldXorOfICmps(LHS, RHS)) + return replaceInstUsesWith(I, Res); + + // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) @@ -2721,6 +2815,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Builder)); } } + } if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 8b71352440b..24ba412ca99 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -239,6 +239,7 @@ public: Instruction *visitFDiv(BinaryOperator &I); Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS); + Value *FoldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS); Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Instruction *visitAnd(BinaryOperator &I); Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI); |