diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 136 | 
1 files changed, 136 insertions, 0 deletions
| diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 0579b979e25..c80bccf002a 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -29348,8 +29348,144 @@ static SDValue OptimizeConditionalInDecrement(SDNode *N, SelectionDAG &DAG) {                       DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);  } +static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, +                                const X86Subtarget &Subtarget) { +  SDLoc DL(N); +  EVT VT = N->getValueType(0); +  SDValue Op0 = N->getOperand(0); +  SDValue Op1 = N->getOperand(1); + +  if (!VT.isVector() || !VT.isSimple() || +      !(VT.getVectorElementType() == MVT::i32)) +    return SDValue(); + +  unsigned RegSize = 128; +  if (Subtarget.hasBWI()) +    RegSize = 512; +  else if (Subtarget.hasAVX2()) +    RegSize = 256; + +  // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. +  if (VT.getSizeInBits() / 4 > RegSize) +    return SDValue(); + +  // Detect the following pattern: +  // +  // 1:    %2 = zext <N x i8> %0 to <N x i32> +  // 2:    %3 = zext <N x i8> %1 to <N x i32> +  // 3:    %4 = sub nsw <N x i32> %2, %3 +  // 4:    %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] +  // 5:    %6 = sub nsw <N x i32> zeroinitializer, %4 +  // 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 +  // 7:    %8 = add nsw <N x i32> %7, %vec.phi +  // +  // The last instruction must be a reduction add. The instructions 3-6 forms an +  // ABSDIFF pattern. + +  // The two operands of reduction add are from PHI and a select-op as in line 7 +  // above. +  SDValue SelectOp, Phi; +  if (Op0.getOpcode() == ISD::VSELECT) { +    SelectOp = Op0; +    Phi = Op1; +  } else if (Op1.getOpcode() == ISD::VSELECT) { +    SelectOp = Op1; +    Phi = Op0; +  } else +    return SDValue(); + +  // Check the condition of the select instruction is greater-than. +  SDValue SetCC = SelectOp->getOperand(0); +  if (SetCC.getOpcode() != ISD::SETCC) +    return SDValue(); +  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); +  if (CC != ISD::SETGT) +    return SDValue(); + +  Op0 = SelectOp->getOperand(1); +  Op1 = SelectOp->getOperand(2); + +  // The second operand of SelectOp Op1 is the negation of the first operand +  // Op0, which is implemented as 0 - Op0. +  if (!(Op1.getOpcode() == ISD::SUB && +        ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) && +        Op1.getOperand(1) == Op0)) +    return SDValue(); + +  // The first operand of SetCC is the first operand of SelectOp, which is the +  // difference between two input vectors. +  if (SetCC.getOperand(0) != Op0) +    return SDValue(); + +  // The second operand of > comparison can be either -1 or 0. +  if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || +        ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) +    return SDValue(); + +  // The first operand of SelectOp is the difference between two input vectors. +  if (Op0.getOpcode() != ISD::SUB) +    return SDValue(); + +  Op1 = Op0.getOperand(1); +  Op0 = Op0.getOperand(0); + +  // Check if the operands of the diff are zero-extended from vectors of i8. +  if (Op0.getOpcode() != ISD::ZERO_EXTEND || +      Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || +      Op1.getOpcode() != ISD::ZERO_EXTEND || +      Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8) +    return SDValue(); + +  // SAD pattern detected. Now build a SAD instruction and an addition for +  // reduction. Note that the number of elments of the result of SAD is less +  // than the number of elements of its input. Therefore, we could only update +  // part of elements in the reduction vector. + +  // Legalize the type of the inputs of PSADBW. +  EVT InVT = Op0.getOperand(0).getValueType(); +  if (InVT.getSizeInBits() <= 128) +    RegSize = 128; +  else if (InVT.getSizeInBits() <= 256) +    RegSize = 256; + +  unsigned NumConcat = RegSize / InVT.getSizeInBits(); +  SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT)); +  Ops[0] = Op0.getOperand(0); +  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8); +  Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops); +  Ops[0] = Op1.getOperand(0); +  Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops); + +  // The output of PSADBW is a vector of i64. +  MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64); +  SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, SadVT, Op0, Op1); + +  // We need to turn the vector of i64 into a vector of i32. +  MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); +  Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); + +  NumConcat = VT.getSizeInBits() / ResVT.getSizeInBits(); +  if (NumConcat > 1) { +    // Update part of elements of the reduction vector. This is done by first +    // extracting a sub-vector from it, updating this sub-vector, and inserting +    // it back. +    SDValue SubPhi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Phi, +                                 DAG.getIntPtrConstant(0, DL)); +    SDValue Res = DAG.getNode(ISD::ADD, DL, ResVT, Sad, SubPhi); +    return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Phi, Res, +                       DAG.getIntPtrConstant(0, DL)); +  } else +    return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi); +} +  static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,                            const X86Subtarget &Subtarget) { +  const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; +  if (Flags->hasVectorReduction()) { +    if (SDValue Sad = detectSADPattern(N, DAG, Subtarget)) +      return Sad; +  } +    EVT VT = N->getValueType(0);    SDValue Op0 = N->getOperand(0);    SDValue Op1 = N->getOperand(1); | 

