diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 2f3f1b86c11..f4bf270d738 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35997,6 +35997,91 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL, return SDValue(); } +/// This function detects the addition or subtraction with saturation pattern +/// between 2 i8/i16 vectors (signed or unsigned) and replaces this operation with the +/// efficient X86ISD::ADDUS/X86ISD::ADDS/X86ISD::SUBUS/X86ISD::SUBS instruction. +static SDValue detectAddSubSatPattern(SDValue In, EVT VT, SelectionDAG &DAG, + const X86Subtarget &Subtarget, + const SDLoc &DL) { + if (!VT.isVector() || !VT.isSimple()) + return SDValue(); + EVT InVT = In.getValueType(); + unsigned NumElems = VT.getVectorNumElements(); + + EVT ScalarVT = VT.getVectorElementType(); + if ((ScalarVT != MVT::i8 && ScalarVT != MVT::i16) || + InVT.getSizeInBits() % 128 != 0 || !isPowerOf2_32(NumElems)) + return SDValue(); + + // InScalarVT is the intermediate type in AddSubSat pattern + // and it should be wider than the original input type (i8/i16). 
+ EVT InScalarVT = InVT.getVectorElementType(); + if (InScalarVT.getSizeInBits() <= ScalarVT.getSizeInBits()) + return SDValue(); + + if (!Subtarget.hasSSE2()) + return SDValue(); + + // Detect the following pattern: + // %2 = zext <16 x i8> %0 to <16 x i16> + // %3 = zext <16 x i8> %1 to <16 x i16> + // %4 = add nuw nsw <16 x i16> %3, %2 + // %5 = icmp ult <16 x i16> %4, <16 x i16> (vector of max InScalarVT values) + // %6 = select <16 x i1> %5, <16 x i16> %4, <16 x i16> (vector of max InScalarVT values) + // %7 = trunc <16 x i16> %6 to <16 x i8> + + // Detect a saturation pattern: try the signed form first, then fall back to the unsigned form. + bool Signed = true; + SDValue Sat = detectSSatPattern(In, VT, false); + if (!Sat) { + Sat = detectUSatPattern(In, VT); + Signed = false; + } + if (!Sat) + return SDValue(); + if (Sat.getOpcode() != ISD::ADD && Sat.getOpcode() != ISD::SUB) + return SDValue(); + + unsigned Opcode = Sat.getOpcode() == ISD::ADD ? Signed ? X86ISD::ADDS + : X86ISD::ADDUS + : Signed ? X86ISD::SUBS + : X86ISD::SUBUS; + + // Get the operands of the ADD/SUB node. + SDValue LHS = Sat.getOperand(0); + SDValue RHS = Sat.getOperand(1); + + // Check if LHS and RHS are both results of type promotion, or + // one of them is and the other one is a constant build vector. NOTE(review): the pre-extension operand type is not compared with VT and the constant is truncated without a visible range check; also confirm the unsigned SUB case matches SUBUS when the wide difference wraps - TODO confirm. + unsigned ExtendOpcode = Signed ? ISD::SIGN_EXTEND : + ISD::ZERO_EXTEND; + unsigned LHSOpcode = LHS.getOpcode(); + unsigned RHSOpcode = RHS.getOpcode(); + + if (LHSOpcode == ExtendOpcode && RHSOpcode == ExtendOpcode) { + LHS = LHS.getOperand(0); + RHS = RHS.getOperand(0); + } else if (LHSOpcode == ExtendOpcode && + ISD::isBuildVectorOfConstantSDNodes(RHS.getNode())) { + LHS = LHS.getOperand(0); + RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS); + } else if (RHSOpcode == ExtendOpcode && + ISD::isBuildVectorOfConstantSDNodes(LHS.getNode())) { + RHS = RHS.getOperand(0); + LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS); + } else + return SDValue(); + + // The pattern is detected, emit ADDS/ADDUS/SUBS/SUBUS instruction. 
+ auto AddSubSatBuilder = [Opcode](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + EVT VT = Ops[0].getValueType(); + return DAG.getNode(Opcode, DL, VT, Ops); + }; + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { LHS, RHS }, + AddSubSatBuilder); +} + static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); @@ -36011,6 +36096,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL)) return Avg; + // Try to detect addition or subtraction with saturation. + if (SDValue AddSubSat = detectAddSubSatPattern(Src, VT, DAG, Subtarget, DL)) + return AddSubSat; + // Try to combine truncation with signed/unsigned saturation. if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget)) return Val; |

