diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 30a9ebc8fc9..9ca25657270 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36102,6 +36102,59 @@ static SDValue detectAddSubSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
                           AddSubSatBuilder);
 }
 
+// Fold (trunc (srl (mul (ext X), (ext Y)), 16)) into MULHS (both extends are
+// SIGN_EXTEND) or MULHU (both ZERO_EXTEND). Src is the truncate's operand and
+// VT its result type, which must equal the type of X and Y; returns an empty
+// SDValue when the pattern does not match.
+// TODO: This is X86 specific because we want to handle wide types before type
+// legalization, which is only valid if the vector will be widened/split: type
+// legalization can't promote MULHU/MULHS, which generic DAG combine can't know.
+static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
+                            SelectionDAG &DAG, const X86Subtarget &Subtarget) {
+  // Src must be a logical right shift (SRL) whose shifted value is a MUL.
+  if (Src.getOpcode() != ISD::SRL ||
+      Src.getOperand(0).getOpcode() != ISD::MUL)
+    return SDValue();
+
+  if (!Subtarget.hasSSE2()) // Only attempted when SSE2 is available.
+    return SDValue();
+
+  // The result must be a vector of i16 with at least 8 elements (>= 128 bits).
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 ||
+      VT.getVectorNumElements() < 8)
+    return SDValue();
+
+  // The shifted multiply itself must be computed on i32 elements.
+  EVT InVT = Src.getValueType();
+  if (InVT.getVectorElementType() != MVT::i32)
+    return SDValue();
+
+  // The shift amount must be a constant splat of 16 (selects the high half).
+  APInt ShiftAmt;
+  if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) ||
+      ShiftAmt != 16)
+    return SDValue();
+
+  SDValue LHS = Src.getOperand(0).getOperand(0); // Multiply operands.
+  SDValue RHS = Src.getOperand(0).getOperand(1);
+
+  unsigned ExtOpc = LHS.getOpcode(); // Both sides must use the same extend.
+  if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
+      RHS.getOpcode() != ExtOpc)
+    return SDValue();
+
+  // Look through the extends to the narrow source values.
+  LHS = LHS.getOperand(0);
+  RHS = RHS.getOperand(0);
+
+  // Both narrow sources must already have the result type VT.
+  if (LHS.getValueType() != VT || RHS.getValueType() != VT)
+    return SDValue();
+
+  unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
+  return DAG.getNode(Opc, DL, VT, LHS, RHS);
+}
+
 static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
                                const X86Subtarget &Subtarget) {
   EVT VT = N->getValueType(0);
@@ -36124,6 +36177,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
   if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
     return Val;
 
+  // Try to form MULHU/MULHS (PMULHUW/PMULHW) from a vXi16 truncate.
+  if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
+    return V;
+
   // The bitcast source is a direct mmx result.
   // Detect bitcasts between i32 to x86mmx
   if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) { |

