diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 42 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 2 |
2 files changed, 44 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 7dd21edd07f..e23e85bb16a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2249,6 +2249,44 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, return nullptr; } +Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp, + BinaryOperator *SRem, + const APInt &C) { + // Match an 'is positive' or 'is negative' comparison of remainder by a + // constant power-of-2 value: + // (X % pow2C) sgt/slt 0 + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT) + return nullptr; + + // TODO: The one-use check is standard because we do not typically want to + // create longer instruction sequences, but this might be a special-case + // because srem is not good for analysis or codegen. + if (!SRem->hasOneUse()) + return nullptr; + + const APInt *DivisorC; + if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC))) + return nullptr; + + // Mask off the sign bit and the modulo bits (low-bits). + Type *Ty = SRem->getType(); + APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits()); + Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1)); + Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC); + + // For 'is positive?' check that the sign-bit is clear and at least 1 masked + // bit is set. Example: + // (i8 X % 32) s> 0 --> (X & 159) s> 0 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty)); + + // For 'is negative?' check that the sign-bit is set and at least 1 masked + // bit is set. Example: + // (i16 X % 4) s< 0 --> (X & 32771) u> 32768 + return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask)); +} + /// Fold icmp (udiv X, Y), C. Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, @@ -2806,6 +2844,10 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; + case Instruction::SRem: + if (Instruction *I = foldICmpSRemConstant(Cmp, BO, *C)) + return I; + break; case Instruction::UDiv: if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 91f5228b370..5e4a56ba257 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -891,6 +891,8 @@ private: const APInt &C); Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, const APInt &C); + Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt &C); Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, const APInt &C); Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, |