diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c7b01fa4eb4..fb31ce6d78e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -34618,6 +34618,69 @@ static SDValue combineAddOrSubToADCOrSBB(SDNode *N, SelectionDAG &DAG) {
                      DAG.getConstant(0, DL, VT), NewCmp);
 }
 
+/// Fold a vector-reduction ADD of a narrow MUL into X86ISD::VPMADDWD.
+///
+/// \p N is the ADD node. One operand must be an ISD::MUL that
+/// canReduceVMulWidth() accepts in a mode other than MULU16; the other operand
+/// (the accumulator) is passed through unchanged. PMADDWD sums adjacent pairs
+/// of products, which is acceptable only because the caller has already
+/// checked that N is flagged as a vector reduction.
+///
+/// \returns the new ADD, or an empty SDValue if the fold does not apply.
+static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  // PMADDWD is an SSE2 instruction; without it there is nothing to select.
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  SDValue MulOp = N->getOperand(0);
+  SDValue Phi = N->getOperand(1);
+
+  // ADD is commutative, so the multiply may be on either side.
+  if (MulOp.getOpcode() != ISD::MUL)
+    std::swap(MulOp, Phi);
+  if (MulOp.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  // PMADDWD treats its 16-bit inputs as signed. In MULU16 mode an operand may
+  // have bit 15 set, which would be read as a negative value after the
+  // truncate below, so that mode must be rejected.
+  ShrinkMode Mode;
+  if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  unsigned RegSize = 128;
+  if (Subtarget.hasBWI())
+    RegSize = 512;
+  else if (Subtarget.hasAVX2())
+    RegSize = 256;
+  // Width, in bits, of the operands once truncated to i16 elements.
+  unsigned VectorSize = VT.getVectorNumElements() * 16;
+  // If the vector size is less than 128, or greater than the supported RegSize,
+  // do not use PMADD.
+  if (VectorSize < 128 || VectorSize > RegSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+                                   VT.getVectorNumElements());
+  EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+                                VT.getVectorNumElements() / 2);
+
+  // Shrink the operands of mul.
+  SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0));
+  SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1));
+
+  // Madd vector size is half of the original vector size
+  SDValue Madd = DAG.getNode(X86ISD::VPMADDWD, DL, MAddVT, N0, N1);
+  // Fill the rest of the output with 0
+  SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL);
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero);
+  return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi);
+}
+
 static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
   SDLoc DL(N);
@@ -34695,6 +34758,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
   if (Flags->hasVectorReduction()) {
     if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget))
       return Sad;
+    if (SDValue MAdd = combineLoopMAddPattern(N, DAG, Subtarget))
+      return MAdd;
   }
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);

