diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 34 |
1 files changed, 22 insertions, 12 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 52ce9ec18e3..17481f7fb26 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -43151,18 +43151,6 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - SDValue MulOp = N->getOperand(0); - SDValue OtherOp = N->getOperand(1); - - if (MulOp.getOpcode() != ISD::MUL) - std::swap(MulOp, OtherOp); - if (MulOp.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); - EVT VT = N->getValueType(0); // If the vector size is less than 128, or greater than the supported RegSize, @@ -43170,6 +43158,28 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!VT.isVector() || VT.getVectorNumElements() < 8) return SDValue(); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + + auto UsePMADDWD = [&](SDValue Op) { + ShrinkMode Mode; + return Op.getOpcode() == ISD::MUL && + canReduceVMulWidth(Op.getNode(), DAG, Mode) && Mode != MULU16 && + (!Subtarget.hasSSE41() || + (Op->isOnlyUserOf(Op.getOperand(0).getNode()) && + Op->isOnlyUserOf(Op.getOperand(1).getNode()))); + }; + + SDValue MulOp, OtherOp; + if (UsePMADDWD(Op0)) { + MulOp = Op0; + OtherOp = Op1; + } else if (UsePMADDWD(Op1)) { + MulOp = Op1; + OtherOp = Op0; + } else + return SDValue(); + SDLoc DL(N); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); |