diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 5de8c7a4fdb..49bc662d129 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -325,6 +325,117 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } +// Attempt to simplify AVX2 per-element shift intrinsics to a generic IR shift. +// Unlike the generic IR shifts, the intrinsics have defined behaviour for out +// of range shift amounts (logical - set to zero, arithmetic - splat sign bit). +static Value *simplifyX86varShift(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + bool LogicalShift = false; + bool ShiftLeft = false; + + switch (II.getIntrinsicID()) { + default: + return nullptr; + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + LogicalShift = false; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + LogicalShift = true; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + LogicalShift = true; + ShiftLeft = true; + break; + } + assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); + + // Simplify if all shift amounts are constant/undef. + auto *CShift = dyn_cast<Constant>(II.getArgOperand(1)); + if (!CShift) + return nullptr; + + auto Vec = II.getArgOperand(0); + auto VT = cast<VectorType>(II.getType()); + auto SVT = VT->getVectorElementType(); + int NumElts = VT->getNumElements(); + int BitWidth = SVT->getIntegerBitWidth(); + + // Collect each element's shift amount. + // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth. + bool AnyOutOfRange = false; + SmallVector<int, 8> ShiftAmts; + for (int I = 0; I < NumElts; ++I) { + auto *CElt = CShift->getAggregateElement(I); + if (CElt && isa<UndefValue>(CElt)) { + ShiftAmts.push_back(-1); + continue; + } + + auto *COp = dyn_cast_or_null<ConstantInt>(CElt); + if (!COp) + return nullptr; + + // Handle out of range shifts. + // If LogicalShift - set to BitWidth (special case). + // If ArithmeticShift - set to (BitWidth - 1) (sign splat). + APInt ShiftVal = COp->getValue(); + if (ShiftVal.uge(BitWidth)) { + AnyOutOfRange = LogicalShift; + ShiftAmts.push_back(LogicalShift ? BitWidth : BitWidth - 1); + continue; + } + + ShiftAmts.push_back((int)ShiftVal.getZExtValue()); + } + + // If all elements out of range or UNDEF, return vector of zeros/undefs. + // ArithmeticShift should only hit this if they are all UNDEF. + auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; + if (llvm::all_of(ShiftAmts, OutOfRange)) { + SmallVector<Constant *, 8> ConstantVec; + for (int Idx : ShiftAmts) { + if (Idx < 0) { + ConstantVec.push_back(UndefValue::get(SVT)); + } else { + assert(LogicalShift && "Logical shift expected"); + ConstantVec.push_back(ConstantInt::getNullValue(SVT)); + } + } + return ConstantVector::get(ConstantVec); + } + + // We can't handle only some out of range values with generic logical shifts. + if (AnyOutOfRange) + return nullptr; + + // Build the shift amount constant vector. + SmallVector<Constant *, 8> ShiftVecAmts; + for (int Idx : ShiftAmts) { + if (Idx < 0) + ShiftVecAmts.push_back(UndefValue::get(SVT)); + else + ShiftVecAmts.push_back(ConstantInt::get(SVT, Idx)); + } + auto ShiftVec = ConstantVector::get(ShiftVecAmts); + + if (ShiftLeft) + return Builder.CreateShl(Vec, ShiftVec); + + if (LogicalShift) + return Builder.CreateLShr(Vec, ShiftVec); + + return Builder.CreateAShr(Vec, ShiftVec); +} + static Value *simplifyX86movmsk(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { Value *Arg = II.getArgOperand(0); @@ -1656,6 +1767,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + if (Value *V = simplifyX86varShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::x86_sse41_insertps: if (Value *V = simplifyX86insertps(*II, *Builder)) return replaceInstUsesWith(*II, V); |