From db9893fb90c29783d2e1fc6195dc8e9a6a149b2a Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Tue, 7 Jun 2016 10:27:15 +0000 Subject: [InstCombine][AVX2] Add support for simplifying AVX2 per-element shifts to native shifts Unlike native shifts, the AVX2 per-element shift instructions VPSRAV/VPSRLV/VPSLLV handle out of range shift values (logical shifts set the result to zero, arithmetic shifts splat the sign bit). If the shift amount is constant we can sometimes convert these instructions to native shifts: 1 - if all shift amounts are in range then the conversion is trivial. 2 - out of range arithmetic shifts can be clamped to the (bitwidth - 1) (a legal shift amount) before conversion. 3 - logical shifts just return zero if all elements have out of range shift amounts. In addition, UNDEF shift amounts are handled - either as an UNDEF shift amount in a native shift or as an UNDEF in the logical 'all out of range' zero constant special case for logical shifts. Differential Revision: http://reviews.llvm.org/D19675 llvm-svn: 271996 --- .../Transforms/InstCombine/InstCombineCalls.cpp | 125 +++++++++++++++++++++ 1 file changed, 125 insertions(+) (limited to 'llvm/lib/Transforms') diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 5de8c7a4fdb..49bc662d129 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -325,6 +325,117 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } +// Attempt to simplify AVX2 per-element shift intrinsics to a generic IR shift. +// Unlike the generic IR shifts, the intrinsics have defined behaviour for out +// of range shift amounts (logical - set to zero, arithmetic - splat sign bit). +static Value *simplifyX86varShift(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + bool LogicalShift = false; + bool ShiftLeft = false; + + switch (II.getIntrinsicID()) { + default: + return nullptr; + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + LogicalShift = false; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + LogicalShift = true; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + LogicalShift = true; + ShiftLeft = true; + break; + } + assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); + + // Simplify if all shift amounts are constant/undef. + auto *CShift = dyn_cast(II.getArgOperand(1)); + if (!CShift) + return nullptr; + + auto Vec = II.getArgOperand(0); + auto VT = cast(II.getType()); + auto SVT = VT->getVectorElementType(); + int NumElts = VT->getNumElements(); + int BitWidth = SVT->getIntegerBitWidth(); + + // Collect each element's shift amount. + // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth. + bool AnyOutOfRange = false; + SmallVector ShiftAmts; + for (int I = 0; I < NumElts; ++I) { + auto *CElt = CShift->getAggregateElement(I); + if (CElt && isa(CElt)) { + ShiftAmts.push_back(-1); + continue; + } + + auto *COp = dyn_cast_or_null(CElt); + if (!COp) + return nullptr; + + // Handle out of range shifts. + // If LogicalShift - set to BitWidth (special case). + // If ArithmeticShift - set to (BitWidth - 1) (sign splat). + APInt ShiftVal = COp->getValue(); + if (ShiftVal.uge(BitWidth)) { + AnyOutOfRange = LogicalShift; + ShiftAmts.push_back(LogicalShift ? BitWidth : BitWidth - 1); + continue; + } + + ShiftAmts.push_back((int)ShiftVal.getZExtValue()); + } + + // If all elements out of range or UNDEF, return vector of zeros/undefs. + // ArithmeticShift should only hit this if they are all UNDEF. + auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; + if (llvm::all_of(ShiftAmts, OutOfRange)) { + SmallVector ConstantVec; + for (int Idx : ShiftAmts) { + if (Idx < 0) { + ConstantVec.push_back(UndefValue::get(SVT)); + } else { + assert(LogicalShift && "Logical shift expected"); + ConstantVec.push_back(ConstantInt::getNullValue(SVT)); + } + } + return ConstantVector::get(ConstantVec); + } + + // We can't handle only some out of range values with generic logical shifts. + if (AnyOutOfRange) + return nullptr; + + // Build the shift amount constant vector. + SmallVector ShiftVecAmts; + for (int Idx : ShiftAmts) { + if (Idx < 0) + ShiftVecAmts.push_back(UndefValue::get(SVT)); + else + ShiftVecAmts.push_back(ConstantInt::get(SVT, Idx)); + } + auto ShiftVec = ConstantVector::get(ShiftVecAmts); + + if (ShiftLeft) + return Builder.CreateShl(Vec, ShiftVec); + + if (LogicalShift) + return Builder.CreateLShr(Vec, ShiftVec); + + return Builder.CreateAShr(Vec, ShiftVec); +} + static Value *simplifyX86movmsk(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { Value *Arg = II.getArgOperand(0); @@ -1656,6 +1767,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + if (Value *V = simplifyX86varShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::x86_sse41_insertps: if (Value *V = simplifyX86insertps(*II, *Builder)) return replaceInstUsesWith(*II, V); -- cgit v1.2.3