Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 59
1 file changed, 42 insertions, 17 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index becf658956d..b76ba58b6c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32906,6 +32906,46 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// If the upper 17 bits of each element are zero then we can use PMADDWD,
+// which is always at least as quick as PMULLD, except on KNL.
+static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  if (Subtarget.getProcFamily() == X86Subtarget::IntelKNL)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  // Only support vXi32 vectors.
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i32)
+    return SDValue();
+
+  // Make sure the vXi16 type is legal. This covers the AVX512 without BWI case.
+  MVT WVT = MVT::getVectorVT(MVT::i16, 2 * VT.getVectorNumElements());
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(WVT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  APInt Mask17 = APInt::getHighBitsSet(32, 17);
+  if (!DAG.MaskedValueIsZero(N1, Mask17) ||
+      !DAG.MaskedValueIsZero(N0, Mask17))
+    return SDValue();
+
+  // Use SplitBinaryOpsAndApply to handle AVX splitting.
+  auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
+                           SDValue Op1) {
+    MVT VT = MVT::getVectorVT(MVT::i32, Op0.getValueSizeInBits() / 32);
+    return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Op0, Op1);
+  };
+  return SplitBinaryOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
+                                DAG.getBitcast(WVT, N0),
+                                DAG.getBitcast(WVT, N1), PMADDWDBuilder);
+}
+
 /// Optimize a single multiply with constant into two operations in order to
 /// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
 static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
@@ -32913,23 +32953,8 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
   EVT VT = N->getValueType(0);
 
-  // If the upper 17 bits of each element are zero then we can use PMADDWD,
-  // which is always at least as quick as PMULLD, expect on KNL.
-  if (Subtarget.getProcFamily() != X86Subtarget::IntelKNL &&
-      ((VT == MVT::v4i32 && Subtarget.hasSSE2()) ||
-       (VT == MVT::v8i32 && Subtarget.hasAVX2()) ||
-       (VT == MVT::v16i32 && Subtarget.useBWIRegs()))) {
-    SDValue N0 = N->getOperand(0);
-    SDValue N1 = N->getOperand(1);
-    APInt Mask17 = APInt::getHighBitsSet(32, 17);
-    if (DAG.MaskedValueIsZero(N0, Mask17) &&
-        DAG.MaskedValueIsZero(N1, Mask17)) {
-      unsigned NumElts = VT.getVectorNumElements();
-      MVT WVT = MVT::getVectorVT(MVT::i16, 2 * NumElts);
-      return DAG.getNode(X86ISD::VPMADDWD, SDLoc(N), VT,
-                         DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1));
-    }
-  }
+  if (SDValue V = combineMulToPMADDWD(N, DAG, DCI, Subtarget))
+    return V;
 
   if (DCI.isBeforeLegalize() && VT.isVector())
     return reduceVMULWidth(N, DAG, Subtarget);
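
Why the combine requires the upper 17 bits (and not just 16) of each element to be zero: after the bitcast, every i32 lane becomes two i16 lanes, and PMADDWD multiplies them as signed 16-bit values. The 17-bit mask guarantees the high half of each lane is zero and the low half fits in 15 bits, i.e. stays non-negative as a signed i16, so the pairwise multiply-add degenerates to the exact 32-bit product that PMULLD would compute. Below is a minimal standalone C++ sketch of this reasoning; it is not part of the patch, and the scalar helper pmaddwdLane is a hypothetical model of a single 32-bit lane of the instruction:

#include <cassert>
#include <cstdint>

// Scalar model of one 32-bit lane of PMADDWD: split each operand into two
// signed 16-bit halves, multiply pairwise, and add the two products.
int32_t pmaddwdLane(uint32_t A, uint32_t B) {
  int16_t ALo = static_cast<int16_t>(A), AHi = static_cast<int16_t>(A >> 16);
  int16_t BLo = static_cast<int16_t>(B), BHi = static_cast<int16_t>(B >> 16);
  return static_cast<int32_t>(ALo) * BLo + static_cast<int32_t>(AHi) * BHi;
}

int main() {
  // With the upper 17 bits of both lanes zero, each value is at most 0x7FFF:
  // the high halves are zero and the low halves are non-negative as signed
  // i16, so the multiply-add reduces to the exact 32-bit product.
  for (uint32_t A = 0; A <= 0x7FFF; A += 0x123)
    for (uint32_t B = 0; B <= 0x7FFF; B += 0x1D1)
      assert(pmaddwdLane(A, B) == static_cast<int32_t>(A * B));
  return 0;
}

If only the upper 16 bits were known zero, a low half such as 0x8000 would be read as -32768 by the signed multiply and the result would be wrong, which is exactly what the Mask17 check rules out.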