diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 50edb59eb8f..b520f0ac905 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -22326,7 +22326,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, assert(Subtarget.hasSSE2() && !Subtarget.hasSSE41() && "Should not custom lower when pmulld is available!"); - // If the upper 17 bits of each element are zero then we can use PMADD. + // If the upper 17 bits of each element are zero then we can use PMADDWD. APInt Mask17 = APInt::getHighBitsSet(32, 17); if (DAG.MaskedValueIsZero(A, Mask17) && DAG.MaskedValueIsZero(B, Mask17)) return DAG.getNode(X86ISD::VPMADDWD, dl, VT, @@ -32707,13 +32707,6 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, if ((NumElts % 2) != 0) return SDValue(); - // If the upper 17 bits of each element are zero then we can use PMADD. - APInt Mask17 = APInt::getHighBitsSet(32, 17); - if (VT == MVT::v4i32 && DAG.MaskedValueIsZero(N0, Mask17) && - DAG.MaskedValueIsZero(N1, Mask17)) - return DAG.getNode(X86ISD::VPMADDWD, DL, VT, DAG.getBitcast(MVT::v8i16, N0), - DAG.getBitcast(MVT::v8i16, N1)); - unsigned RegSize = 128; MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts); @@ -32885,6 +32878,25 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); + + // If the upper 17 bits of each element are zero then we can use PMADDWD, + // which is always at least as quick as PMULLD, except on KNL. 
+ if (Subtarget.getProcFamily() != X86Subtarget::IntelKNL && + ((VT == MVT::v4i32 && Subtarget.hasSSE2()) || + (VT == MVT::v8i32 && Subtarget.hasAVX2()) || + (VT == MVT::v16i32 && Subtarget.hasBWI()))) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + APInt Mask17 = APInt::getHighBitsSet(32, 17); + if (DAG.MaskedValueIsZero(N0, Mask17) && + DAG.MaskedValueIsZero(N1, Mask17)) { + unsigned NumElts = VT.getVectorNumElements(); + MVT WVT = MVT::getVectorVT(MVT::i16, 2 * NumElts); + return DAG.getNode(X86ISD::VPMADDWD, SDLoc(N), VT, + DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1)); + } + } + if (DCI.isBeforeLegalize() && VT.isVector()) return reduceVMULWidth(N, DAG, Subtarget); |

