diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 52 |
1 file changed, 45 insertions, 7 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8f01210d666..5f08c9dd120 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -34109,6 +34109,41 @@ static SDValue detectUSatPattern(SDValue In, EVT VT) {
   return SDValue();
 }
 
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+///                  signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+///                  signed_min_of_dest_type)) to dest_type).
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+  unsigned NumSrcBits = In.getScalarValueSizeInBits();
+  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+  auto MatchMinMax = [](SDValue V, unsigned Opcode, const APInt &Limit) {
+    APInt C;
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+      return V.getOperand(0); // Matched Opcode(X, splat(Limit)): yield X.
+    return SDValue(); // Not the requested clamp.
+  };
+
+  APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+  APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
+    if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
+      return SMax; // In = smin(smax(x, SignedMin), SignedMax); yields x.
+
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin))
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax))
+      return SMin; // In = smax(smin(x, SignedMax), SignedMin); yields x.
+
+  return SDValue(); // Neither clamp ordering matched.
+}
+
 /// Detect a pattern of truncation with saturation:
 /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
 /// The types should allow to use VPMOVUS* instruction on AVX512.
@@ -34121,15 +34156,18 @@ static SDValue detectAVX512USatPattern(SDValue In, EVT VT,
   return detectUSatPattern(In, VT);
 }
 
-static SDValue combineTruncateWithUSat(SDValue In, EVT VT, const SDLoc &DL,
-                                       SelectionDAG &DAG,
-                                       const X86Subtarget &Subtarget) {
+static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
+                                      SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   if (!TLI.isTypeLegal(In.getValueType()) || !TLI.isTypeLegal(VT))
     return SDValue();
-  if (auto USatVal = detectUSatPattern(In, VT))
-    if (isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget))
+  if (isSATValidOnAVX512Subtarget(In.getValueType(), VT, Subtarget)) {
+    if (auto SSatVal = detectSSatPattern(In, VT))
+      return DAG.getNode(X86ISD::VTRUNCS, DL, VT, SSatVal); // Signed sat.
+    if (auto USatVal = detectUSatPattern(In, VT))
       return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal);
+  }
   return SDValue();
 }
 
@@ -35393,8 +35431,8 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
   if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
     return Avg;
 
-  // Try to combine truncation with unsigned saturation.
-  if (SDValue Val = combineTruncateWithUSat(Src, VT, DL, DAG, Subtarget))
+  // Try to combine truncation with signed/unsigned saturation.
+  if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
     return Val;
 
   // The bitcast source is a direct mmx result.
|