Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 38
1 file changed, 31 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index de413c42348..600c8c36392 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -199,7 +199,9 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
 
 static Value *SimplifyX86immshift(const IntrinsicInst &II,
                                   InstCombiner::BuilderTy &Builder,
-                                  bool ShiftLeft) {
+                                  bool LogicalShift, bool ShiftLeft) {
+  assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
+
   // Simplify if count is constant.
   auto Arg1 = II.getArgOperand(1);
   auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1);
@@ -238,9 +240,15 @@ static Value *SimplifyX86immshift(const IntrinsicInst &II,
   if (Count == 0)
     return Vec;
 
-  // Handle cases when Shift >= BitWidth - just return zero.
-  if (Count.uge(BitWidth))
-    return ConstantAggregateZero::get(VT);
+  // Handle cases when Shift >= BitWidth.
+  if (Count.uge(BitWidth)) {
+    // If LogicalShift - just return zero.
+    if (LogicalShift)
+      return ConstantAggregateZero::get(VT);
+
+    // If ArithmeticShift - clamp Shift to (BitWidth - 1).
+    Count = APInt(64, BitWidth - 1);
+  }
 
   // Get a constant vector of the same type as the first operand.
   auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth));
@@ -249,7 +257,10 @@ static Value *SimplifyX86immshift(const IntrinsicInst &II,
   if (ShiftLeft)
     return Builder.CreateShl(Vec, ShiftVec);
 
-  return Builder.CreateLShr(Vec, ShiftVec);
+  if (LogicalShift)
+    return Builder.CreateLShr(Vec, ShiftVec);
+
+  return Builder.CreateAShr(Vec, ShiftVec);
 }
 
 static Value *SimplifyX86extend(const IntrinsicInst &II,
@@ -776,6 +787,19 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     break;
   }
 
+  // Constant fold ashr( <A x Bi>, Ci ).
+  case Intrinsic::x86_sse2_psra_d:
+  case Intrinsic::x86_sse2_psra_w:
+  case Intrinsic::x86_sse2_psrai_d:
+  case Intrinsic::x86_sse2_psrai_w:
+  case Intrinsic::x86_avx2_psra_d:
+  case Intrinsic::x86_avx2_psra_w:
+  case Intrinsic::x86_avx2_psrai_d:
+  case Intrinsic::x86_avx2_psrai_w:
+    if (Value *V = SimplifyX86immshift(*II, *Builder, false, false))
+      return ReplaceInstUsesWith(*II, V);
+    break;
+
   // Constant fold lshr( <A x Bi>, Ci ).
   case Intrinsic::x86_sse2_psrl_d:
   case Intrinsic::x86_sse2_psrl_q:
@@ -789,7 +813,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::x86_avx2_psrli_d:
   case Intrinsic::x86_avx2_psrli_q:
   case Intrinsic::x86_avx2_psrli_w:
-    if (Value *V = SimplifyX86immshift(*II, *Builder, false))
+    if (Value *V = SimplifyX86immshift(*II, *Builder, true, false))
       return ReplaceInstUsesWith(*II, V);
     break;
 
@@ -806,7 +830,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::x86_avx2_pslli_d:
   case Intrinsic::x86_avx2_pslli_q:
   case Intrinsic::x86_avx2_pslli_w:
-    if (Value *V = SimplifyX86immshift(*II, *Builder, true))
+    if (Value *V = SimplifyX86immshift(*II, *Builder, true, true))
       return ReplaceInstUsesWith(*II, V);
     break;
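
Note on the clamp: for the logical shifts (psrl/psll), a count of BitWidth or more zeroes every element, so the fold returns a zero vector. For the arithmetic shifts (psra), the hardware fills each element with copies of its sign bit once the count reaches the element width, which is the same result as shifting by (BitWidth - 1); hence the clamp instead of a zero vector. Below is a minimal standalone sketch of that per-element rule, written in plain C++ rather than against the LLVM API; foldShiftElement is a hypothetical helper used only for illustration, and signed right shift is assumed to be arithmetic (guaranteed since C++20).

// Standalone sketch of the per-element folding rule used by SimplifyX86immshift.
// Illustrative only; not LLVM code.
#include <cassert>
#include <cstdint>
#include <iostream>

static int32_t foldShiftElement(int32_t Elt, uint64_t Count, bool LogicalShift,
                                bool ShiftLeft) {
  assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
  const unsigned BitWidth = 32;

  if (Count >= BitWidth) {
    if (LogicalShift)
      return 0;               // shl/lshr by >= BitWidth folds to zero
    Count = BitWidth - 1;     // ashr is clamped, matching PSRAD's sign-fill
  }

  if (ShiftLeft)
    return (int32_t)((uint32_t)Elt << Count);
  if (LogicalShift)
    return (int32_t)((uint32_t)Elt >> Count);
  return Elt >> Count;        // arithmetic shift preserves the sign bit
}

int main() {
  // Logical right shift by 40 (>= 32) folds to 0.
  std::cout << foldShiftElement(-8, 40, /*Logical=*/true, /*Left=*/false) << '\n';
  // Arithmetic right shift by 40 is clamped to 31: -8 becomes -1, 8 becomes 0.
  std::cout << foldShiftElement(-8, 40, /*Logical=*/false, /*Left=*/false) << '\n';
  std::cout << foldShiftElement(8, 40, /*Logical=*/false, /*Left=*/false) << '\n';
}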