diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index a366db4d1e3..36180c00348 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -898,6 +898,130 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } +/// General pattern: +/// X & Y +/// +/// Where Y is checking that all the high bits (covered by a mask 4294967168) +/// are uniform, i.e. %arg & 4294967168 can be either 4294967168 or 0 +/// Pattern can be one of: +/// %t = add i32 %arg, 128 +/// %r = icmp ult i32 %t, 256 +/// Or +/// %t0 = shl i32 %arg, 24 +/// %t1 = ashr i32 %t0, 24 +/// %r = icmp eq i32 %t1, %arg +/// Or +/// %t0 = trunc i32 %arg to i8 +/// %t1 = sext i8 %t0 to i32 +/// %r = icmp eq i32 %t1, %arg +/// This pattern is a signed truncation check. +/// +/// And X is checking that some bit in that same mask is zero. +/// I.e. can be one of: +/// %r = icmp sgt i32 %arg, -1 +/// Or +/// %t = and i32 %arg, 2147483648 +/// %r = icmp eq i32 %t, 0 +/// +/// Since we are checking that all the bits in that mask are the same, +/// and a particular bit is zero, what we are really checking is that all the +/// masked bits are zero. +/// So this should be transformed to: +/// %r = icmp ult i32 %arg, 128 +static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, + Instruction &CxtI, + InstCombiner::BuilderTy &Builder) { + assert(CxtI.getOpcode() == Instruction::And); + + // Match icmp ult (add %arg, C01), C1 (C1 == C01 << 1; powers of two) + auto tryToMatchSignedTruncationCheck = [](ICmpInst *ICmp, Value *&X, + APInt &SignBitMask) -> bool { + CmpInst::Predicate Pred; + const APInt *I01, *I1; // powers of two; I1 == I01 << 1 + if (!(match(ICmp, + m_ICmp(Pred, m_Add(m_Value(X), m_Power2(I01)), m_Power2(I1))) && + Pred == ICmpInst::ICMP_ULT && I1->ugt(*I01) && I01->shl(1) == *I1)) + return false; + // Which bit is the new sign bit as per the 'signed truncation' pattern? + SignBitMask = *I01; + return true; + }; + + // One icmp needs to be 'signed truncation check'. + // We need to match this first, else we will mismatch commutative cases. + Value *X1; + APInt HighestBit; + ICmpInst *OtherICmp; + if (tryToMatchSignedTruncationCheck(ICmp1, X1, HighestBit)) + OtherICmp = ICmp0; + else if (tryToMatchSignedTruncationCheck(ICmp0, X1, HighestBit)) + OtherICmp = ICmp1; + else + return nullptr; + + assert(HighestBit.isPowerOf2() && "expected to be power of two (non-zero)"); + + // Try to match/decompose into: icmp eq (X & Mask), 0 + auto tryToDecompose = [](ICmpInst *ICmp, Value *&X, + APInt &UnsetBitsMask) -> bool { + CmpInst::Predicate Pred = ICmp->getPredicate(); + // Can it be decomposed into icmp eq (X & Mask), 0 ? + if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), + Pred, X, UnsetBitsMask, + /*LookThruTrunc=*/false) && + Pred == ICmpInst::ICMP_EQ) + return true; + // Is it icmp eq (X & Mask), 0 already? + const APInt *Mask; + if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) && + Pred == ICmpInst::ICMP_EQ) { + UnsetBitsMask = *Mask; + return true; + } + return false; + }; + + // And the other icmp needs to be decomposable into a bit test. + Value *X0; + APInt UnsetBitsMask; + if (!tryToDecompose(OtherICmp, X0, UnsetBitsMask)) + return nullptr; + + assert(!UnsetBitsMask.isNullValue() && "empty mask makes no sense."); + + // Are they working on the same value? + Value *X; + if (X1 == X0) { + // Ok as is. + X = X1; + } else if (match(X0, m_Trunc(m_Specific(X1)))) { + UnsetBitsMask = UnsetBitsMask.zext(X1->getType()->getScalarSizeInBits()); + X = X1; + } else + return nullptr; + + // So which bits should be uniform as per the 'signed truncation check'? + // (all the bits starting with (i.e. including) HighestBit) + APInt SignBitsMask = ~(HighestBit - 1U); + + // UnsetBitsMask must have some common bits with SignBitsMask, + if (!UnsetBitsMask.intersects(SignBitsMask)) + return nullptr; + + // Does UnsetBitsMask contain any bits outside of SignBitsMask? + if (!UnsetBitsMask.isSubsetOf(SignBitsMask)) { + APInt OtherHighestBit = (~UnsetBitsMask) + 1U; + if (!OtherHighestBit.isPowerOf2()) + return nullptr; + HighestBit = APIntOps::umin(HighestBit, OtherHighestBit); + } + // Else, if it does not, then all is ok as-is. + + // %r = icmp ult %X, SignBit + return Builder.CreateICmpULT(X, ConstantInt::get(X->getType(), HighestBit), + CxtI.getName() + ".simplified"); +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI) { @@ -937,6 +1061,9 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) return V; + if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); |