summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp143
1 files changed, 143 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8a6c73732d5..f7dbf3c4238 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36753,6 +36753,145 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
return DAG.getNode(Opc, DL, VT, LHS, RHS);
}
+// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
+// from one vector with signed bytes from another vector, adds together
+// adjacent pairs of 16-bit products, and saturates the result before
+// truncating to 16-bits.
+//
+// Which looks something like this:
+// (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))),
+// (mul (zext (odd elts (i8 A)), (sext (odd elts (i8 B))))))))
+static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget,
+ const SDLoc &DL) {
+ if (!VT.isVector() || !Subtarget.hasSSSE3())
+ return SDValue();
+
+ unsigned NumElems = VT.getVectorNumElements();
+ EVT ScalarVT = VT.getVectorElementType();
+ if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems))
+ return SDValue();
+
+ SDValue SSatVal = detectSSatPattern(In, VT);
+ if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
+ return SDValue();
+
+ // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
+ // of multiplies from even/odd elements.
+ SDValue N0 = SSatVal.getOperand(0);
+ SDValue N1 = SSatVal.getOperand(1);
+
+ if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+ SDValue N10 = N1.getOperand(0);
+ SDValue N11 = N1.getOperand(1);
+
+ // TODO: Handle constant vectors and use knownbits/computenumsignbits?
+ // Canonicalize zero_extend to LHS.
+ if (N01.getOpcode() == ISD::ZERO_EXTEND)
+ std::swap(N00, N01);
+ if (N11.getOpcode() == ISD::ZERO_EXTEND)
+ std::swap(N10, N11);
+
+ // Ensure we have a zero_extend and a sign_extend.
+ if (N00.getOpcode() != ISD::ZERO_EXTEND ||
+ N01.getOpcode() != ISD::SIGN_EXTEND ||
+ N10.getOpcode() != ISD::ZERO_EXTEND ||
+ N11.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ // Peek through the extends.
+ N00 = N00.getOperand(0);
+ N01 = N01.getOperand(0);
+ N10 = N10.getOperand(0);
+ N11 = N11.getOperand(0);
+
+ // Ensure the extend is from vXi8.
+ if (N00.getValueType().getVectorElementType() != MVT::i8 ||
+ N01.getValueType().getVectorElementType() != MVT::i8 ||
+ N10.getValueType().getVectorElementType() != MVT::i8 ||
+ N11.getValueType().getVectorElementType() != MVT::i8)
+ return SDValue();
+
+ // All inputs should be build_vectors.
+ if (N00.getOpcode() != ISD::BUILD_VECTOR ||
+ N01.getOpcode() != ISD::BUILD_VECTOR ||
+ N10.getOpcode() != ISD::BUILD_VECTOR ||
+ N11.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // N00/N10 are zero extended. N01/N11 are sign extended.
+
+ // For each element, we need to ensure we have an odd element from one vector
+ // multiplied by the odd element of another vector and the even element from
+ // one of the same vectors being multiplied by the even element from the
+ // other vector. So we need to make sure for each element i, this operator
+ // is being performed:
+ // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
+ SDValue ZExtIn, SExtIn;
+ for (unsigned i = 0; i != NumElems; ++i) {
+ SDValue N00Elt = N00.getOperand(i);
+ SDValue N01Elt = N01.getOperand(i);
+ SDValue N10Elt = N10.getOperand(i);
+ SDValue N11Elt = N11.getOperand(i);
+ // TODO: Be more tolerant to undefs.
+ if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+ auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
+ auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
+ auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
+ auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
+ if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+ return SDValue();
+ unsigned IdxN00 = ConstN00Elt->getZExtValue();
+ unsigned IdxN01 = ConstN01Elt->getZExtValue();
+ unsigned IdxN10 = ConstN10Elt->getZExtValue();
+ unsigned IdxN11 = ConstN11Elt->getZExtValue();
+ // Add is commutative so indices can be reordered.
+ if (IdxN00 > IdxN10) {
+ std::swap(IdxN00, IdxN10);
+ std::swap(IdxN01, IdxN11);
+ }
+ // N0 indices be the even element. N1 indices must be the next odd element.
+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
+ IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+ return SDValue();
+ SDValue N00In = N00Elt.getOperand(0);
+ SDValue N01In = N01Elt.getOperand(0);
+ SDValue N10In = N10Elt.getOperand(0);
+ SDValue N11In = N11Elt.getOperand(0);
+ // First time we find an input capture it.
+ if (!ZExtIn) {
+ ZExtIn = N00In;
+ SExtIn = N01In;
+ }
+ if (ZExtIn != N00In || SExtIn != N01In ||
+ ZExtIn != N10In || SExtIn != N11In)
+ return SDValue();
+ }
+
+ auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ // Shrink by adding truncate nodes and let DAGCombine fold with the
+ // sources.
+ EVT InVT = Ops[0].getValueType();
+ assert(InVT.getScalarType() == MVT::i8 &&
+ "Unexpected scalar element type");
+ assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+ EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+ InVT.getVectorNumElements() / 2);
+ return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn },
+ PMADDBuilder);
+}
+
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
@@ -36767,6 +36906,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
return Avg;
+ // Try to detect PMADD
+ if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL))
+ return PMAdd;
+
// Try to combine truncation with signed/unsigned saturation.
if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
return Val;
OpenPOWER on IntegriCloud