summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp29
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineInternal.h4
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp53
-rw-r--r--llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll12
4 files changed, 66 insertions, 32 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index ee51bc03312..601295339ad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1358,19 +1358,28 @@ Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
/// Fold equality-comparison between zero and any (maybe truncated) right-shift
/// by one-less-than-bitwidth into a sign test on the original value.
-Instruction *foldSignBitTest(ICmpInst &I) {
+Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) {
+ Instruction *Val;
ICmpInst::Predicate Pred;
- Value *X;
- Constant *C;
- if (!I.isEquality() ||
- !match(&I, m_ICmp(Pred, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))),
- m_Zero())))
+ if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
return nullptr;
- Type *XTy = X->getType();
- unsigned XBitWidth = XTy->getScalarSizeInBits();
- if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
- APInt(XBitWidth, XBitWidth - 1))))
+ Value *X;
+ Type *XTy;
+
+ Constant *C;
+ if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) {
+ XTy = X->getType();
+ unsigned XBitWidth = XTy->getScalarSizeInBits();
+ if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(XBitWidth, XBitWidth - 1))))
+ return nullptr;
+ } else if (isa<BinaryOperator>(Val) &&
+ (X = reassociateShiftAmtsOfTwoSameDirectionShifts(
+ cast<BinaryOperator>(Val), SQ.getWithInstruction(Val),
+ /*AnalyzeForSignBitExtraction=*/true))) {
+ XTy = X->getType();
+ } else
return nullptr;
return ICmpInst::Create(Instruction::ICmp,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index e04cd346b6f..4519dc0bf37 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -390,6 +390,9 @@ public:
Instruction *visitOr(BinaryOperator &I);
Instruction *visitXor(BinaryOperator &I);
Instruction *visitShl(BinaryOperator &I);
+ Value *reassociateShiftAmtsOfTwoSameDirectionShifts(
+ BinaryOperator *Sh0, const SimplifyQuery &SQ,
+ bool AnalyzeForSignBitExtraction = false);
Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
BinaryOperator &OldAShr);
Instruction *visitAShr(BinaryOperator &I);
@@ -912,6 +915,7 @@ private:
Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
Instruction *foldICmpEquality(ICmpInst &Cmp);
Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
+ Instruction *foldSignBitTest(ICmpInst &I);
Instruction *foldICmpWithZero(ICmpInst &Cmp);
Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index cc0e35e4a9c..64294838644 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -25,10 +25,12 @@ using namespace PatternMatch;
// we should rewrite it as
// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
// This is valid for any shift, but they must be identical.
-static Instruction *
-reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
- const SimplifyQuery &SQ,
- InstCombiner::BuilderTy &Builder) {
+//
+// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
+// pattern has any 2 right-shifts that sum to 1 less than original bit width.
+Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
+ BinaryOperator *Sh0, const SimplifyQuery &SQ,
+ bool AnalyzeForSignBitExtraction) {
// Look for a shift of some instruction, ignore zext of shift amount if any.
Instruction *Sh0Op0;
Value *ShAmt0;
@@ -56,14 +58,25 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
if (ShAmt0->getType() != ShAmt1->getType())
return nullptr;
- // The shift opcodes must be identical.
+ // We are only looking for signbit extraction if we have two right shifts.
+ bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
+ match(Sh1, m_Shr(m_Value(), m_Value()));
+ // ... and if it's not two right-shifts, we know the answer already.
+ if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
+ return nullptr;
+
+ // The shift opcodes must be identical, unless we are just checking whether
+ // this pattern can be interpreted as a sign-bit-extraction.
Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
- if (ShiftOpcode != Sh1->getOpcode())
+ bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
+ if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
return nullptr;
// If we saw truncation, we'll need to produce extra instruction,
- // and for that one of the operands of the shift must be one-use.
- if (Trunc && !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ // and for that one of the operands of the shift must be one-use,
+ // unless of course we don't actually plan to produce any instructions here.
+ if (Trunc && !AnalyzeForSignBitExtraction &&
+ !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
return nullptr;
// Can we fold (ShAmt0+ShAmt1) ?
@@ -80,14 +93,22 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
return nullptr; // FIXME: could perform constant-folding.
// If there was a truncation, and we have a right-shift, we can only fold if
- // we are left with the original sign bit.
+ // we are left with the original sign bit. Likewise, if we were just checking
+ // that this is a sighbit extraction, this is the place to check it.
// FIXME: zero shift amount is also legal here, but we can't *easily* check
// more than one predicate so it's not really worth it.
- if (Trunc && ShiftOpcode != Instruction::BinaryOps::Shl &&
- !match(NewShAmt,
- m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
- APInt(NewShAmtBitWidth, XBitWidth - 1))))
- return nullptr;
+ if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
+ // If it's not a sign bit extraction, then we're done.
+ if (!match(NewShAmt,
+ m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(NewShAmtBitWidth, XBitWidth - 1))))
+ return nullptr;
+ // If it is, and that was the question, return the base value.
+ if (AnalyzeForSignBitExtraction)
+ return X;
+ }
+
+ assert(IdenticalShOpcodes && "Should not get here with different shifts.");
// All good, we can do this fold.
NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
@@ -287,8 +308,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
return Res;
- if (Instruction *NewShift =
- reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ, Builder))
+ if (auto *NewShift = cast_or_null<Instruction>(
+ reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
return NewShift;
// (C1 shift (A add C2)) -> (C1 shift C2) shift A)
diff --git a/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll b/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll
index 0e5848e7303..8e89a0649eb 100644
--- a/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll
+++ b/llvm/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll
@@ -45,7 +45,7 @@ define i1 @highest_bit_test_via_lshr_with_truncation(i64 %data, i32 %nbits) {
; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
; CHECK-NEXT: ret i1 [[ISNEG]]
;
%num_low_bits_to_skip = sub i32 64, %nbits
@@ -107,7 +107,7 @@ define i1 @highest_bit_test_via_ashr_with_truncation(i64 %data, i32 %nbits) {
; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
; CHECK-NEXT: ret i1 [[ISNEG]]
;
%num_low_bits_to_skip = sub i32 64, %nbits
@@ -138,7 +138,7 @@ define i1 @highest_bit_test_via_lshr_ashr(i32 %data, i32 %nbits) {
; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED]])
; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0
; CHECK-NEXT: ret i1 [[ISNEG]]
;
%num_low_bits_to_skip = sub i32 32, %nbits
@@ -169,7 +169,7 @@ define i1 @highest_bit_test_via_lshr_ashe_with_truncation(i64 %data, i32 %nbits)
; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
; CHECK-NEXT: ret i1 [[ISNEG]]
;
%num_low_bits_to_skip = sub i32 64, %nbits
@@ -200,7 +200,7 @@ define i1 @highest_bit_test_via_ashr_lshr(i32 %data, i32 %nbits) {
; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED]])
; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0
; CHECK-NEXT: ret i1 [[ISNEG]]
;
%num_low_bits_to_skip = sub i32 32, %nbits
@@ -231,7 +231,7 @@ define i1 @highest_bit_test_via_ashr_lshr_with_truncation(i64 %data, i32 %nbits)
; CHECK-NEXT: call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
; CHECK-NEXT: call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
; CHECK-NEXT: call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT: [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT: [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
; CHECK-NEXT: ret i1 [[ISNEG]]
;
%num_low_bits_to_skip = sub i32 64, %nbits
OpenPOWER on IntegriCloud