diff options
Diffstat (limited to 'llvm')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 24 | ||||
-rw-r--r-- | llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll | 7 |
2 files changed, 27 insertions, 4 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index fbff5dd4a8c..739579e2d38 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -23,8 +23,11 @@ using namespace PatternMatch; // Given pattern: // (x shiftopcode Q) shiftopcode K // we should rewrite it as -// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) -// This is valid for any shift, but they must be identical. +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and +// +// This is valid for any shift, but they must be identical, and we must be +// careful in case we have (zext(Q)+zext(K)) and look past extensions, +// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus. // // AnalyzeForSignBitExtraction indicates that we will only analyze whether this // pattern has any 2 right-shifts that sum to 1 less than original bit width. @@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( if (ShAmt0->getType() != ShAmt1->getType()) return nullptr; + // As input, we have the following pattern: + // Sh0 (Sh1 X, Q), K + // We want to rewrite that as: + // Sh x, (Q+K) iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts. so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. + unsigned MaximalPossibleTotalShiftAmount = + (Sh0->getType()->getScalarSizeInBits() - 1) + + (Sh1->getType()->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) + return nullptr; + // We are only looking for signbit extraction if we have two right shifts. bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && match(Sh1, m_Shr(m_Value(), m_Value())); diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll index 0b8187d0417..96461691e70 100644 --- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll +++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll @@ -320,12 +320,15 @@ define i32 @n20(i32 %x, i32 %y) { ret i32 %t3 } -; FIXME: this is a miscompile. We should not transform this. ; See https://bugs.llvm.org/show_bug.cgi?id=44802 define i3 @pr44802(i3 %t0) { ; CHECK-LABEL: @pr44802( ; CHECK-NEXT: [[T1:%.*]] = sub i3 0, [[T0:%.*]] -; CHECK-NEXT: ret i3 [[T1]] +; CHECK-NEXT: [[T2:%.*]] = icmp ne i3 [[T0]], 0 +; CHECK-NEXT: [[T3:%.*]] = zext i1 [[T2]] to i3 +; CHECK-NEXT: [[T4:%.*]] = lshr i3 [[T1]], [[T3]] +; CHECK-NEXT: [[T5:%.*]] = lshr i3 [[T4]], [[T3]] +; CHECK-NEXT: ret i3 [[T5]] ; %t1 = sub i3 0, %t0 %t2 = icmp ne i3 %t0, 0 |