diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 55 |
1 files changed, 39 insertions, 16 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 668789b516d..de413c42348 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -200,33 +200,56 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { static Value *SimplifyX86immshift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder, bool ShiftLeft) { - // Simplify if count is constant. To 0 if >= BitWidth, - // otherwise to shl/lshr. - auto CDV = dyn_cast<ConstantDataVector>(II.getArgOperand(1)); - auto CInt = dyn_cast<ConstantInt>(II.getArgOperand(1)); - if (!CDV && !CInt) + // Simplify if count is constant. + auto Arg1 = II.getArgOperand(1); + auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1); + auto CDV = dyn_cast<ConstantDataVector>(Arg1); + auto CInt = dyn_cast<ConstantInt>(Arg1); + if (!CAZ && !CDV && !CInt) return nullptr; - ConstantInt *Count; - if (CDV) - Count = cast<ConstantInt>(CDV->getElementAsConstant(0)); - else - Count = CInt; + + APInt Count(64, 0); + if (CDV) { + // SSE2/AVX2 uses all the first 64-bits of the 128-bit vector + // operand to compute the shift amount. + auto VT = cast<VectorType>(CDV->getType()); + unsigned BitWidth = VT->getElementType()->getPrimitiveSizeInBits(); + assert((64 % BitWidth) == 0 && "Unexpected packed shift size"); + unsigned NumSubElts = 64 / BitWidth; + + // Concatenate the sub-elements to create the 64-bit value. + for (unsigned i = 0; i != NumSubElts; ++i) { + unsigned SubEltIdx = (NumSubElts - 1) - i; + auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx)); + Count = Count.shl(BitWidth); + Count |= SubElt->getValue().zextOrTrunc(64); + } + } + else if (CInt) + Count = CInt->getValue(); auto Vec = II.getArgOperand(0); auto VT = cast<VectorType>(Vec->getType()); auto SVT = VT->getElementType(); - if (Count->getZExtValue() > (SVT->getPrimitiveSizeInBits() - 1)) - return ConstantAggregateZero::get(VT); - unsigned VWidth = VT->getNumElements(); + unsigned BitWidth = SVT->getPrimitiveSizeInBits(); + + // If shift-by-zero then just return the original value. + if (Count == 0) + return Vec; + + // Handle cases when Shift >= BitWidth - just return zero. + if (Count.uge(BitWidth)) + return ConstantAggregateZero::get(VT); // Get a constant vector of the same type as the first operand. - auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue()); + auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth)); + auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt); if (ShiftLeft) - return Builder.CreateShl(Vec, Builder.CreateVectorSplat(VWidth, VTCI)); + return Builder.CreateShl(Vec, ShiftVec); - return Builder.CreateLShr(Vec, Builder.CreateVectorSplat(VWidth, VTCI)); + return Builder.CreateLShr(Vec, ShiftVec); } static Value *SimplifyX86extend(const IntrinsicInst &II, |