summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target/X86/X86ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp123
1 files changed, 123 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7572cf4a6ec..e827c5c35f1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -38816,6 +38816,127 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
PMADDBuilder);
}
+// Attempt to turn this pattern into PMADDWD.
+// (add (mul (sext (build_vector)), (sext (build_vector))),
+//      (mul (sext (build_vector)), (sext (build_vector))))
+static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
+ const SDLoc &DL, EVT VT,
+ const X86Subtarget &Subtarget) {
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
+ VT.getVectorNumElements() < 4 ||
+ !isPowerOf2_32(VT.getVectorNumElements()))
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+ SDValue N10 = N1.getOperand(0);
+ SDValue N11 = N1.getOperand(1);
+
+ // All inputs need to be sign extends.
+ // TODO: Support ZERO_EXTEND from known positive?
+ if (N00.getOpcode() != ISD::SIGN_EXTEND ||
+ N01.getOpcode() != ISD::SIGN_EXTEND ||
+ N10.getOpcode() != ISD::SIGN_EXTEND ||
+ N11.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ // Peek through the extends.
+ N00 = N00.getOperand(0);
+ N01 = N01.getOperand(0);
+ N10 = N10.getOperand(0);
+ N11 = N11.getOperand(0);
+
+ // Must be extending from vXi16.
+ EVT InVT = N00.getValueType();
+ if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
+ N10.getValueType() != InVT || N11.getValueType() != InVT)
+ return SDValue();
+
+ // All inputs should be build_vectors.
+ if (N00.getOpcode() != ISD::BUILD_VECTOR ||
+ N01.getOpcode() != ISD::BUILD_VECTOR ||
+ N10.getOpcode() != ISD::BUILD_VECTOR ||
+ N11.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // For each element, we need to ensure we have an odd element from one vector
+ // multiplied by the odd element of another vector and the even element from
+ // one of the same vectors being multiplied by the even element from the
+ // other vector. So we need to make sure for each element i, this operation
+ // is being performed:
+ // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
+ SDValue In0, In1;
+ for (unsigned i = 0; i != N00.getNumOperands(); ++i) {
+ SDValue N00Elt = N00.getOperand(i);
+ SDValue N01Elt = N01.getOperand(i);
+ SDValue N10Elt = N10.getOperand(i);
+ SDValue N11Elt = N11.getOperand(i);
+ // TODO: Be more tolerant to undefs.
+ if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+ auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
+ auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
+ auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
+ auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
+ if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+ return SDValue();
+ unsigned IdxN00 = ConstN00Elt->getZExtValue();
+ unsigned IdxN01 = ConstN01Elt->getZExtValue();
+ unsigned IdxN10 = ConstN10Elt->getZExtValue();
+ unsigned IdxN11 = ConstN11Elt->getZExtValue();
+ // Add is commutative so indices can be reordered.
+ if (IdxN00 > IdxN10) {
+ std::swap(IdxN00, IdxN10);
+ std::swap(IdxN01, IdxN11);
+ }
+ // N0 indices must be the even element, N1 indices the next odd element.
+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
+ IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+ return SDValue();
+ SDValue N00In = N00Elt.getOperand(0);
+ SDValue N01In = N01Elt.getOperand(0);
+ SDValue N10In = N10Elt.getOperand(0);
+ SDValue N11In = N11Elt.getOperand(0);
+ // First time we find an input capture it.
+ if (!In0) {
+ In0 = N00In;
+ In1 = N01In;
+ }
+ // Mul is commutative so the input vectors can be in any order.
+ // Canonicalize to make the compares easier.
+ if (In0 != N00In)
+ std::swap(N00In, N01In);
+ if (In0 != N10In)
+ std::swap(N10In, N11In);
+ if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In)
+ return SDValue();
+ }
+
+ auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ // Emit VPMADDWD on the vXi16 operands. Each i32 result element combines a
+ // pair of i16 elements, so the result has half as many elements.
+ EVT InVT = Ops[0].getValueType();
+ assert(InVT.getScalarType() == MVT::i16 &&
+ "Unexpected scalar element type");
+ assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+ EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+ InVT.getVectorNumElements() / 2);
+ return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
+ PMADDBuilder);
+}
+
static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
const SDNodeFlags Flags = N->getFlags();
@@ -38831,6 +38952,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
return MAdd;
+ if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget))
+ return MAdd;
// Try to synthesize horizontal adds from adds of shuffles.
if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 ||
OpenPOWER on IntegriCloud