diff options
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 96 |
1 files changed, 45 insertions, 51 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 51c72eb1837..b1bb9281ea2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -541,7 +541,8 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { +static Value *simplifyX86pack(IntrinsicInst &II, + InstCombiner::BuilderTy &Builder, bool IsSigned) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); Type *ResTy = II.getType(); @@ -552,68 +553,61 @@ static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { Type *ArgTy = Arg0->getType(); unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; - unsigned NumDstElts = ResTy->getVectorNumElements(); unsigned NumSrcElts = ArgTy->getVectorNumElements(); - assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types"); + assert(ResTy->getVectorNumElements() == (2 * NumSrcElts) && + "Unexpected packing types"); - unsigned NumDstEltsPerLane = NumDstElts / NumLanes; unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits(); - assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) && + unsigned SrcScalarSizeInBits = ArgTy->getScalarSizeInBits(); + assert(SrcScalarSizeInBits == (2 * DstScalarSizeInBits) && "Unexpected packing types"); // Constant folding. - auto *Cst0 = dyn_cast<Constant>(Arg0); - auto *Cst1 = dyn_cast<Constant>(Arg1); - if (!Cst0 || !Cst1) + if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) return nullptr; - SmallVector<Constant *, 32> Vals; - for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { - for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) { - unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane; - auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0; - auto *COp = Cst->getAggregateElement(SrcIdx); - if (COp && isa<UndefValue>(COp)) { - Vals.push_back(UndefValue::get(ResTy->getScalarType())); - continue; - } + // Clamp Values - signed/unsigned both use signed clamp values, but they + // differ on the min/max values. + APInt MinValue, MaxValue; + if (IsSigned) { + // PACKSS: Truncate signed value with signed saturation. + // Source values less than dst minint are saturated to minint. + // Source values greater than dst maxint are saturated to maxint. + MinValue = + APInt::getSignedMinValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits); + MaxValue = + APInt::getSignedMaxValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits); + } else { + // PACKUS: Truncate signed value with unsigned saturation. + // Source values less than zero are saturated to zero. + // Source values greater than dst maxuint are saturated to maxuint. + MinValue = APInt::getNullValue(SrcScalarSizeInBits); + MaxValue = APInt::getLowBitsSet(SrcScalarSizeInBits, DstScalarSizeInBits); + } - auto *CInt = dyn_cast_or_null<ConstantInt>(COp); - if (!CInt) - return nullptr; + auto *MinC = Constant::getIntegerValue(ArgTy, MinValue); + auto *MaxC = Constant::getIntegerValue(ArgTy, MaxValue); + Arg0 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg0, MinC), MinC, Arg0); + Arg1 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg1, MinC), MinC, Arg1); + Arg0 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg0, MaxC), MaxC, Arg0); + Arg1 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg1, MaxC), MaxC, Arg1); - APInt Val = CInt->getValue(); - assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() && - "Unexpected constant bitwidth"); - - if (IsSigned) { - // PACKSS: Truncate signed value with signed saturation. - // Source values less than dst minint are saturated to minint. - // Source values greater than dst maxint are saturated to maxint. - if (Val.isSignedIntN(DstScalarSizeInBits)) - Val = Val.trunc(DstScalarSizeInBits); - else if (Val.isNegative()) - Val = APInt::getSignedMinValue(DstScalarSizeInBits); - else - Val = APInt::getSignedMaxValue(DstScalarSizeInBits); - } else { - // PACKUS: Truncate signed value with unsigned saturation. - // Source values less than zero are saturated to zero. - // Source values greater than dst maxuint are saturated to maxuint. - if (Val.isIntN(DstScalarSizeInBits)) - Val = Val.trunc(DstScalarSizeInBits); - else if (Val.isNegative()) - Val = APInt::getNullValue(DstScalarSizeInBits); - else - Val = APInt::getAllOnesValue(DstScalarSizeInBits); - } + // Truncate clamped args to dst size. + auto *TruncTy = VectorType::get(ResTy->getScalarType(), NumSrcElts); + Arg0 = Builder.CreateTrunc(Arg0, TruncTy); + Arg1 = Builder.CreateTrunc(Arg1, TruncTy); - Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val)); - } + // Shuffle args together at the lane level. + SmallVector<unsigned, 32> PackMask; + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) + PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane)); + for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) + PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane) + NumSrcElts); } - return ConstantVector::get(Vals); + return Builder.CreateShuffleVector(Arg0, Arg1, PackMask); } // Replace X86-specific intrinsics with generic floor-ceil where applicable. @@ -2977,7 +2971,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_packsswb: case Intrinsic::x86_avx512_packssdw_512: case Intrinsic::x86_avx512_packsswb_512: - if (Value *V = simplifyX86pack(*II, true)) + if (Value *V = simplifyX86pack(*II, Builder, true)) return replaceInstUsesWith(*II, V); break; @@ -2987,7 +2981,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_packuswb: case Intrinsic::x86_avx512_packusdw_512: case Intrinsic::x86_avx512_packuswb_512: - if (Value *V = simplifyX86pack(*II, false)) + if (Value *V = simplifyX86pack(*II, Builder, false)) return replaceInstUsesWith(*II, V); break; |