summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp38
1 files changed, 31 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index de413c42348..600c8c36392 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -199,7 +199,9 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
static Value *SimplifyX86immshift(const IntrinsicInst &II,
InstCombiner::BuilderTy &Builder,
- bool ShiftLeft) {
+ bool LogicalShift, bool ShiftLeft) {
+ assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
+
// Simplify if count is constant.
auto Arg1 = II.getArgOperand(1);
auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1);
@@ -238,9 +240,15 @@ static Value *SimplifyX86immshift(const IntrinsicInst &II,
if (Count == 0)
return Vec;
- // Handle cases when Shift >= BitWidth - just return zero.
- if (Count.uge(BitWidth))
- return ConstantAggregateZero::get(VT);
+ // Handle cases when Shift >= BitWidth.
+ if (Count.uge(BitWidth)) {
+ // If LogicalShift - just return zero.
+ if (LogicalShift)
+ return ConstantAggregateZero::get(VT);
+
+ // If ArithmeticShift - clamp Shift to (BitWidth - 1).
+ Count = APInt(64, BitWidth - 1);
+ }
// Get a constant vector of the same type as the first operand.
auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth));
@@ -249,7 +257,10 @@ static Value *SimplifyX86immshift(const IntrinsicInst &II,
if (ShiftLeft)
return Builder.CreateShl(Vec, ShiftVec);
- return Builder.CreateLShr(Vec, ShiftVec);
+ if (LogicalShift)
+ return Builder.CreateLShr(Vec, ShiftVec);
+
+ return Builder.CreateAShr(Vec, ShiftVec);
}
static Value *SimplifyX86extend(const IntrinsicInst &II,
@@ -776,6 +787,19 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
+ // Constant fold ashr( <A x Bi>, Ci ).
+ case Intrinsic::x86_sse2_psra_d:
+ case Intrinsic::x86_sse2_psra_w:
+ case Intrinsic::x86_sse2_psrai_d:
+ case Intrinsic::x86_sse2_psrai_w:
+ case Intrinsic::x86_avx2_psra_d:
+ case Intrinsic::x86_avx2_psra_w:
+ case Intrinsic::x86_avx2_psrai_d:
+ case Intrinsic::x86_avx2_psrai_w:
+ if (Value *V = SimplifyX86immshift(*II, *Builder, false, false))
+ return ReplaceInstUsesWith(*II, V);
+ break;
+
// Constant fold lshr( <A x Bi>, Ci ).
case Intrinsic::x86_sse2_psrl_d:
case Intrinsic::x86_sse2_psrl_q:
@@ -789,7 +813,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_avx2_psrli_d:
case Intrinsic::x86_avx2_psrli_q:
case Intrinsic::x86_avx2_psrli_w:
- if (Value *V = SimplifyX86immshift(*II, *Builder, false))
+ if (Value *V = SimplifyX86immshift(*II, *Builder, true, false))
return ReplaceInstUsesWith(*II, V);
break;
@@ -806,7 +830,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_avx2_pslli_d:
case Intrinsic::x86_avx2_pslli_q:
case Intrinsic::x86_avx2_pslli_w:
- if (Value *V = SimplifyX86immshift(*II, *Builder, true))
+ if (Value *V = SimplifyX86immshift(*II, *Builder, true, true))
return ReplaceInstUsesWith(*II, V);
break;
OpenPOWER on IntegriCloud