summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp57
1 files changed, 57 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 30a9ebc8fc9..9ca25657270 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36102,6 +36102,59 @@ static SDValue detectAddSubSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
AddSubSatBuilder);
}
+// Try to form a MULHU or MULHS node by looking for
+// (trunc (srl (mul ext, ext), 16))
+// TODO: This is X86 specific because we want to be able to handle wide types
+// before type legalization. But we can only do it if the vector will be
+// legalized via widening/splitting. Type legalization can't handle promotion
+// of a MULHU/MULHS. There isn't a way to convey this to the generic DAG
+// combiner.
+static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
+ SelectionDAG &DAG, const X86Subtarget &Subtarget) {
+ // First instruction should be a right shift of a multiply.
+ if (Src.getOpcode() != ISD::SRL ||
+ Src.getOperand(0).getOpcode() != ISD::MUL)
+ return SDValue();
+
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ // Only handle vXi16 types that are at least 128-bits.
+ if (!VT.isVector() || VT.getVectorElementType() != MVT::i16 ||
+ VT.getVectorNumElements() < 8)
+ return SDValue();
+
+ // Input type should be vXi32.
+ EVT InVT = Src.getValueType();
+ if (InVT.getVectorElementType() != MVT::i32)
+ return SDValue();
+
+ // Need a shift by 16.
+ APInt ShiftAmt;
+ if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) ||
+ ShiftAmt != 16)
+ return SDValue();
+
+ SDValue LHS = Src.getOperand(0).getOperand(0);
+ SDValue RHS = Src.getOperand(0).getOperand(1);
+
+ unsigned ExtOpc = LHS.getOpcode();
+ if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
+ RHS.getOpcode() != ExtOpc)
+ return SDValue();
+
+ // Peek through the extends.
+ LHS = LHS.getOperand(0);
+ RHS = RHS.getOperand(0);
+
+ // Ensure the input types match.
+ if (LHS.getValueType() != VT || RHS.getValueType() != VT)
+ return SDValue();
+
+ unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
+ return DAG.getNode(Opc, DL, VT, LHS, RHS);
+}
+
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
@@ -36124,6 +36177,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
return Val;
+ // Try to combine PMULHUW/PMULHW for vXi16.
+ if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
+ return V;
+
// The bitcast source is a direct mmx result.
// Detect bitcasts between i32 to x86mmx
if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {
OpenPOWER on IntegriCloud