diff options
Diffstat (limited to 'llvm/lib/Target/X86')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 7572cf4a6ec..e827c5c35f1 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -38816,6 +38816,127 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1, PMADDBuilder); } +// Attempt to turn this pattern into PMADDWD. +// (mul (add (zext (build_vector)), (zext (build_vector))), +// (add (zext (build_vector)), (zext (build_vector))) +static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1, + const SDLoc &DL, EVT VT, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + + if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) + return SDValue(); + + if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 || + VT.getVectorNumElements() < 4 || + !isPowerOf2_32(VT.getVectorNumElements())) + return SDValue(); + + SDValue N00 = N0.getOperand(0); + SDValue N01 = N0.getOperand(1); + SDValue N10 = N1.getOperand(0); + SDValue N11 = N1.getOperand(1); + + // All inputs need to be sign extends. + // TODO: Support ZERO_EXTEND from known positive? + if (N00.getOpcode() != ISD::SIGN_EXTEND || + N01.getOpcode() != ISD::SIGN_EXTEND || + N10.getOpcode() != ISD::SIGN_EXTEND || + N11.getOpcode() != ISD::SIGN_EXTEND) + return SDValue(); + + // Peek through the extends. + N00 = N00.getOperand(0); + N01 = N01.getOperand(0); + N10 = N10.getOperand(0); + N11 = N11.getOperand(0); + + // Must be extending from vXi16. + EVT InVT = N00.getValueType(); + if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT || + N10.getValueType() != InVT || N11.getValueType() != InVT) + return SDValue(); + + // All inputs should be build_vectors. + if (N00.getOpcode() != ISD::BUILD_VECTOR || + N01.getOpcode() != ISD::BUILD_VECTOR || + N10.getOpcode() != ISD::BUILD_VECTOR || + N11.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + // For each element, we need to ensure we have an odd element from one vector + // multiplied by the odd element of another vector and the even element from + // one of the same vectors being multiplied by the even element from the + // other vector. So we need to make sure for each element i, this operator + // is being performed: + // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] + SDValue In0, In1; + for (unsigned i = 0; i != N00.getNumOperands(); ++i) { + SDValue N00Elt = N00.getOperand(i); + SDValue N01Elt = N01.getOperand(i); + SDValue N10Elt = N10.getOperand(i); + SDValue N11Elt = N11.getOperand(i); + // TODO: Be more tolerant to undefs. + if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1)); + auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1)); + auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1)); + auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1)); + if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt) + return SDValue(); + unsigned IdxN00 = ConstN00Elt->getZExtValue(); + unsigned IdxN01 = ConstN01Elt->getZExtValue(); + unsigned IdxN10 = ConstN10Elt->getZExtValue(); + unsigned IdxN11 = ConstN11Elt->getZExtValue(); + // Add is commutative so indices can be reordered. + if (IdxN00 > IdxN10) { + std::swap(IdxN00, IdxN10); + std::swap(IdxN01, IdxN11); + } + // N0 indices be the even elemtn. N1 indices must be the next odd element. + if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || + IdxN01 != 2 * i || IdxN11 != 2 * i + 1) + return SDValue(); + SDValue N00In = N00Elt.getOperand(0); + SDValue N01In = N01Elt.getOperand(0); + SDValue N10In = N10Elt.getOperand(0); + SDValue N11In = N11Elt.getOperand(0); + // First time we find an input capture it. + if (!In0) { + In0 = N00In; + In1 = N01In; + } + // Mul is commutative so the input vectors can be in any order. + // Canonicalize to make the compares easier. + if (In0 != N00In) + std::swap(N00In, N01In); + if (In0 != N10In) + std::swap(N10In, N11In); + if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In) + return SDValue(); + } + + auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef<SDValue> Ops) { + // Shrink by adding truncate nodes and let DAGCombine fold with the + // sources. + EVT InVT = Ops[0].getValueType(); + assert(InVT.getScalarType() == MVT::i16 && + "Unexpected scalar element type"); + assert(InVT == Ops[1].getValueType() && "Operands' types mismatch"); + EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, + InVT.getVectorNumElements() / 2); + return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]); + }; + return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 }, + PMADDBuilder); +} + static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { const SDNodeFlags Flags = N->getFlags(); @@ -38831,6 +38952,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, SDLoc(N), VT, Subtarget)) return MAdd; + if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, SDLoc(N), VT, Subtarget)) + return MAdd; // Try to synthesize horizontal adds from adds of shuffles. if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 || |

