diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2015-08-07 18:22:50 +0000 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2015-08-07 18:22:50 +0000 |
commit | 3815c16bf86c44afa3bf4bd255b730330c218c33 (patch) | |
tree | 847179c61aca11a0f9849df04a579dc80d81b6be /llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | |
parent | 855ea0f71af3c0646faa436508daadb47834d716 (diff) | |
download | bcm5719-llvm-3815c16bf86c44afa3bf4bd255b730330c218c33.tar.gz bcm5719-llvm-3815c16bf86c44afa3bf4bd255b730330c218c33.zip |
[InstCombine] Fix SSE2/AVX2 vector logical shift by constant
This patch fixes the sse2/avx2 vector shift by constant instcombine call to correctly deal with the fact that the shift amount is formed from the entire lower 64-bit and not just the lowest element as it currently assumes.
e.g.
%1 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v, <4 x i32> <i32 15, i32 15, i32 15, i32 15>)
In this case, (V)PSRLD doesn't perform a lshr by 15 but in fact attempts to shift by 64424509455 ((15 << 32) | 15) - giving a zero result.
In addition, this review also recognizes shift-by-zero from a ConstantAggregateZero type (PR23821).
Differential Revision: http://reviews.llvm.org/D11760
llvm-svn: 244341
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 55 |
1 files changed, 39 insertions, 16 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 668789b516d..de413c42348 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -200,33 +200,56 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { static Value *SimplifyX86immshift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder, bool ShiftLeft) { - // Simplify if count is constant. To 0 if >= BitWidth, - // otherwise to shl/lshr. - auto CDV = dyn_cast<ConstantDataVector>(II.getArgOperand(1)); - auto CInt = dyn_cast<ConstantInt>(II.getArgOperand(1)); - if (!CDV && !CInt) + // Simplify if count is constant. + auto Arg1 = II.getArgOperand(1); + auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1); + auto CDV = dyn_cast<ConstantDataVector>(Arg1); + auto CInt = dyn_cast<ConstantInt>(Arg1); + if (!CAZ && !CDV && !CInt) return nullptr; - ConstantInt *Count; - if (CDV) - Count = cast<ConstantInt>(CDV->getElementAsConstant(0)); - else - Count = CInt; + + APInt Count(64, 0); + if (CDV) { + // SSE2/AVX2 uses all the first 64-bits of the 128-bit vector + // operand to compute the shift amount. + auto VT = cast<VectorType>(CDV->getType()); + unsigned BitWidth = VT->getElementType()->getPrimitiveSizeInBits(); + assert((64 % BitWidth) == 0 && "Unexpected packed shift size"); + unsigned NumSubElts = 64 / BitWidth; + + // Concatenate the sub-elements to create the 64-bit value. + for (unsigned i = 0; i != NumSubElts; ++i) { + unsigned SubEltIdx = (NumSubElts - 1) - i; + auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx)); + Count = Count.shl(BitWidth); + Count |= SubElt->getValue().zextOrTrunc(64); + } + } + else if (CInt) + Count = CInt->getValue(); auto Vec = II.getArgOperand(0); auto VT = cast<VectorType>(Vec->getType()); auto SVT = VT->getElementType(); - if (Count->getZExtValue() > (SVT->getPrimitiveSizeInBits() - 1)) - return ConstantAggregateZero::get(VT); - unsigned VWidth = VT->getNumElements(); + unsigned BitWidth = SVT->getPrimitiveSizeInBits(); + + // If shift-by-zero then just return the original value. + if (Count == 0) + return Vec; + + // Handle cases when Shift >= BitWidth - just return zero. + if (Count.uge(BitWidth)) + return ConstantAggregateZero::get(VT); // Get a constant vector of the same type as the first operand. - auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue()); + auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth)); + auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt); if (ShiftLeft) - return Builder.CreateShl(Vec, Builder.CreateVectorSplat(VWidth, VTCI)); + return Builder.CreateShl(Vec, ShiftVec); - return Builder.CreateLShr(Vec, Builder.CreateVectorSplat(VWidth, VTCI)); + return Builder.CreateLShr(Vec, ShiftVec); } static Value *SimplifyX86extend(const IntrinsicInst &II, |