diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 8a6c73732d5..f7dbf3c4238 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -36753,6 +36753,145 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, return DAG.getNode(Opc, DL, VT, LHS, RHS); } +// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes +// from one vector with signed bytes from another vector, adds together +// adjacent pairs of 16-bit products, and saturates the result before +// truncating to 16-bits. +// +// Which looks something like this: +// (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))), +// (mul (zext (odd elts (i8 A)), (sext (odd elts (i8 B)))))))) +static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG, + const X86Subtarget &Subtarget, + const SDLoc &DL) { + if (!VT.isVector() || !Subtarget.hasSSSE3()) + return SDValue(); + + unsigned NumElems = VT.getVectorNumElements(); + EVT ScalarVT = VT.getVectorElementType(); + if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems)) + return SDValue(); + + SDValue SSatVal = detectSSatPattern(In, VT); + if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) + return SDValue(); + + // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs + // of multiplies from even/odd elements. + SDValue N0 = SSatVal.getOperand(0); + SDValue N1 = SSatVal.getOperand(1); + + if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) + return SDValue(); + + SDValue N00 = N0.getOperand(0); + SDValue N01 = N0.getOperand(1); + SDValue N10 = N1.getOperand(0); + SDValue N11 = N1.getOperand(1); + + // TODO: Handle constant vectors and use knownbits/computenumsignbits? + // Canonicalize zero_extend to LHS. + if (N01.getOpcode() == ISD::ZERO_EXTEND) + std::swap(N00, N01); + if (N11.getOpcode() == ISD::ZERO_EXTEND) + std::swap(N10, N11); + + // Ensure we have a zero_extend and a sign_extend. + if (N00.getOpcode() != ISD::ZERO_EXTEND || + N01.getOpcode() != ISD::SIGN_EXTEND || + N10.getOpcode() != ISD::ZERO_EXTEND || + N11.getOpcode() != ISD::SIGN_EXTEND) + return SDValue(); + + // Peek through the extends. + N00 = N00.getOperand(0); + N01 = N01.getOperand(0); + N10 = N10.getOperand(0); + N11 = N11.getOperand(0); + + // Ensure the extend is from vXi8. + if (N00.getValueType().getVectorElementType() != MVT::i8 || + N01.getValueType().getVectorElementType() != MVT::i8 || + N10.getValueType().getVectorElementType() != MVT::i8 || + N11.getValueType().getVectorElementType() != MVT::i8) + return SDValue(); + + // All inputs should be build_vectors. + if (N00.getOpcode() != ISD::BUILD_VECTOR || + N01.getOpcode() != ISD::BUILD_VECTOR || + N10.getOpcode() != ISD::BUILD_VECTOR || + N11.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + // N00/N10 are zero extended. N01/N11 are sign extended. + + // For each element, we need to ensure we have an odd element from one vector + // multiplied by the odd element of another vector and the even element from + // one of the same vectors being multiplied by the even element from the + // other vector. So we need to make sure for each element i, this operator + // is being performed: + // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] + SDValue ZExtIn, SExtIn; + for (unsigned i = 0; i != NumElems; ++i) { + SDValue N00Elt = N00.getOperand(i); + SDValue N01Elt = N01.getOperand(i); + SDValue N10Elt = N10.getOperand(i); + SDValue N11Elt = N11.getOperand(i); + // TODO: Be more tolerant to undefs. + if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1)); + auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1)); + auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1)); + auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1)); + if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt) + return SDValue(); + unsigned IdxN00 = ConstN00Elt->getZExtValue(); + unsigned IdxN01 = ConstN01Elt->getZExtValue(); + unsigned IdxN10 = ConstN10Elt->getZExtValue(); + unsigned IdxN11 = ConstN11Elt->getZExtValue(); + // Add is commutative so indices can be reordered. + if (IdxN00 > IdxN10) { + std::swap(IdxN00, IdxN10); + std::swap(IdxN01, IdxN11); + } + // N0 indices be the even element. N1 indices must be the next odd element. + if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || + IdxN01 != 2 * i || IdxN11 != 2 * i + 1) + return SDValue(); + SDValue N00In = N00Elt.getOperand(0); + SDValue N01In = N01Elt.getOperand(0); + SDValue N10In = N10Elt.getOperand(0); + SDValue N11In = N11Elt.getOperand(0); + // First time we find an input capture it. + if (!ZExtIn) { + ZExtIn = N00In; + SExtIn = N01In; + } + if (ZExtIn != N00In || SExtIn != N01In || + ZExtIn != N10In || SExtIn != N11In) + return SDValue(); + } + + auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + // Shrink by adding truncate nodes and let DAGCombine fold with the + // sources. + EVT InVT = Ops[0].getValueType(); + assert(InVT.getScalarType() == MVT::i8 && + "Unexpected scalar element type"); + assert(InVT == Ops[1].getValueType() && "Operands' types mismatch"); + EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, + InVT.getVectorNumElements() / 2); + return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]); + }; + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn }, + PMADDBuilder); +} + static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); @@ -36767,6 +36906,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL)) return Avg; + // Try to detect PMADD + if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL)) + return PMAdd; + // Try to combine truncation with signed/unsigned saturation. if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget)) return Val; |