diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 130 |
1 files changed, 128 insertions, 2 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 9eb8d5d1aea..9e046c9e3a6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -576,6 +576,105 @@ static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { return ConstantVector::get(Vals); } +// Replace X86-specific intrinsics with generic floor-ceil where applicable. +static Value *simplifyX86round(IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + ConstantInt *Arg = nullptr; + Intrinsic::ID IntrinsicID = II.getIntrinsicID(); + + if (IntrinsicID == Intrinsic::x86_sse41_round_ss || + IntrinsicID == Intrinsic::x86_sse41_round_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(2)); + else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); + else + Arg = dyn_cast<ConstantInt>(II.getArgOperand(1)); + if (!Arg) + return nullptr; + unsigned RoundControl = Arg->getZExtValue(); + + Arg = nullptr; + unsigned SAE = 0; + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); + else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(5)); + else + SAE = 4; + if (!SAE) { + if (!Arg) + return nullptr; + SAE = Arg->getZExtValue(); + } + + if (SAE != 4 || (RoundControl != 2 /*ceil*/ && RoundControl != 1 /*floor*/)) + return nullptr; + + Value *Src, *Dst, *Mask; + bool IsScalar = false; + if (IntrinsicID == Intrinsic::x86_sse41_round_ss || + IntrinsicID == Intrinsic::x86_sse41_round_sd || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + IsScalar = true; + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + Mask = II.getArgOperand(3); + Value *Zero = Constant::getNullValue(Mask->getType()); + Mask = Builder.CreateAnd(Mask, 1); + Mask = Builder.CreateICmp(ICmpInst::ICMP_NE, Mask, Zero); + Dst = II.getArgOperand(2); + } else + Dst = II.getArgOperand(0); + Src = Builder.CreateExtractElement(II.getArgOperand(1), (uint64_t)0); + } else { + Src = II.getArgOperand(0); + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_128 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_256 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_128 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_256 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) { + Dst = II.getArgOperand(2); + Mask = II.getArgOperand(3); + } else { + Dst = Src; + Mask = ConstantInt::getAllOnesValue( + Builder.getIntNTy(Src->getType()->getVectorNumElements())); + } + } + + Intrinsic::ID ID = (RoundControl == 2) ? Intrinsic::ceil : Intrinsic::floor; + Value *Res = Builder.CreateIntrinsic(ID, {Src}, &II); + if (!IsScalar) { + if (auto *C = dyn_cast<Constant>(Mask)) + if (C->isAllOnesValue()) + return Res; + auto *MaskTy = VectorType::get( + Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder.CreateBitCast(Mask, MaskTy); + unsigned Width = Src->getType()->getVectorNumElements(); + if (MaskTy->getVectorNumElements() > Width) { + uint32_t Indices[4]; + for (unsigned i = 0; i != Width; ++i) + Indices[i] = i; + Mask = Builder.CreateShuffleVector(Mask, Mask, + makeArrayRef(Indices, Width)); + } + return Builder.CreateSelect(Mask, Res, Dst); + } + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + Dst = Builder.CreateExtractElement(Dst, (uint64_t)0); + Res = Builder.CreateSelect(Mask, Res, Dst); + Dst = II.getArgOperand(0); + } + return Builder.CreateInsertElement(Dst, Res, (uint64_t)0); +} + static Value *simplifyX86movmsk(const IntrinsicInst &II) { Value *Arg = II.getArgOperand(0); Type *ResTy = II.getType(); @@ -2222,6 +2321,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_sse41_round_ps: + case Intrinsic::x86_sse41_round_pd: + case Intrinsic::x86_avx_round_ps_256: + case Intrinsic::x86_avx_round_pd_256: + case Intrinsic::x86_avx512_mask_rndscale_ps_128: + case Intrinsic::x86_avx512_mask_rndscale_ps_256: + case Intrinsic::x86_avx512_mask_rndscale_ps_512: + case Intrinsic::x86_avx512_mask_rndscale_pd_128: + case Intrinsic::x86_avx512_mask_rndscale_pd_256: + case Intrinsic::x86_avx512_mask_rndscale_pd_512: + case Intrinsic::x86_avx512_mask_rndscale_ss: + case Intrinsic::x86_avx512_mask_rndscale_sd: + if (Value *V = simplifyX86round(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::x86_mmx_pmovmskb: case Intrinsic::x86_sse_movmsk_ps: case Intrinsic::x86_sse2_movmsk_pd: @@ -2438,8 +2553,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_cmp_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: case Intrinsic::x86_xop_vfrcz_ss: case Intrinsic::x86_xop_vfrcz_sd: { unsigned VWidth = II->getType()->getVectorNumElements(); @@ -2452,6 +2565,19 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } else if (Value *V = simplifyX86round(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + } // Constant fold ashr( <A x Bi>, Ci ). // Constant fold lshr( <A x Bi>, Ci ). |

