diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 78 |
1 files changed, 51 insertions, 27 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 75398c7ba7e..8db8ed8f2bc 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23435,6 +23435,14 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::FNMSUB_RND: return "X86ISD::FNMSUB_RND"; case X86ISD::FMADDSUB_RND: return "X86ISD::FMADDSUB_RND"; case X86ISD::FMSUBADD_RND: return "X86ISD::FMSUBADD_RND"; + case X86ISD::FMADDS1_RND: return "X86ISD::FMADDS1_RND"; + case X86ISD::FNMADDS1_RND: return "X86ISD::FNMADDS1_RND"; + case X86ISD::FMSUBS1_RND: return "X86ISD::FMSUBS1_RND"; + case X86ISD::FNMSUBS1_RND: return "X86ISD::FNMSUBS1_RND"; + case X86ISD::FMADDS3_RND: return "X86ISD::FMADDS3_RND"; + case X86ISD::FNMADDS3_RND: return "X86ISD::FNMADDS3_RND"; + case X86ISD::FMSUBS3_RND: return "X86ISD::FMSUBS3_RND"; + case X86ISD::FNMSUBS3_RND: return "X86ISD::FNMSUBS3_RND"; case X86ISD::VPMADD52H: return "X86ISD::VPMADD52H"; case X86ISD::VPMADD52L: return "X86ISD::VPMADD52L"; case X86ISD::VRNDSCALE: return "X86ISD::VRNDSCALE"; @@ -31709,14 +31717,17 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG, unsigned NewOpcode = 0; if (Arg.hasOneUse()) { switch (Arg.getOpcode()) { - case X86ISD::FMADD: NewOpcode = X86ISD::FNMSUB; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FMSUB; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FMADD; break; - case X86ISD::FMADD_RND: NewOpcode = X86ISD::FNMSUB_RND; break; - case X86ISD::FMSUB_RND: NewOpcode = X86ISD::FNMADD_RND; break; - case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND; break; - case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND; break; + case X86ISD::FMADD: NewOpcode = X86ISD::FNMSUB; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FMSUB; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FMADD; break; + case X86ISD::FMADD_RND: NewOpcode = X86ISD::FNMSUB_RND; break; + case X86ISD::FMSUB_RND: NewOpcode = X86ISD::FNMADD_RND; break; + case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND; break; + case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND; break; + // We can't handle scalar intrinsic node here because it would only + // invert one element and not the whole vector. But we could try to handle + // a negation of the lower element only. } } if (NewOpcode) @@ -32250,15 +32261,6 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, SDValue B = N->getOperand(1); SDValue C = N->getOperand(2); - auto isScalarMaskedNode = [&](SDValue &V) { - if (V.hasOneUse()) - return false; - for (auto User : V.getNode()->uses()) - if (User->getOpcode() == X86ISD::SELECTS && N->isOperandOf(User)) - return true; - return false; - }; - auto invertIfNegative = [](SDValue &V) { if (SDValue NegVal = isFNEG(V.getNode())) { V = NegVal; @@ -32267,10 +32269,11 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, return false; }; - // Do not convert scalar masked operations. - bool NegA = !isScalarMaskedNode(A) && invertIfNegative(A); - bool NegB = !isScalarMaskedNode(B) && invertIfNegative(B); - bool NegC = !isScalarMaskedNode(C) && invertIfNegative(C); + // Do not convert the passthru input of scalar intrinsics. + // FIXME: We could allow negations of the lower element only. + bool NegA = N->getOpcode() != X86ISD::FMADDS1_RND && invertIfNegative(A); + bool NegB = invertIfNegative(B); + bool NegC = N->getOpcode() != X86ISD::FMADDS3_RND && invertIfNegative(C); // Negative multiplication when NegA xor NegB bool NegMul = (NegA != NegB); @@ -32281,16 +32284,35 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, else NewOpcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB; + if (N->getOpcode() == X86ISD::FMADD_RND) { switch (NewOpcode) { - case X86ISD::FMADD: NewOpcode = X86ISD::FMADD_RND; break; - case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break; - case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break; + case X86ISD::FMADD: NewOpcode = X86ISD::FMADD_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break; } - return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); + } else if (N->getOpcode() == X86ISD::FMADDS1_RND) { + switch (NewOpcode) { + case X86ISD::FMADD: NewOpcode = X86ISD::FMADDS1_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1_RND; break; + } + } else if (N->getOpcode() == X86ISD::FMADDS3_RND) { + switch (NewOpcode) { + case X86ISD::FMADD: NewOpcode = X86ISD::FMADDS3_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3_RND; break; + } + } else { + assert((N->getOpcode() == X86ISD::FMADD || N->getOpcode() == ISD::FMA) && + "Unexpected opcode!"); + return DAG.getNode(NewOpcode, dl, VT, A, B, C); } - return DAG.getNode(NewOpcode, dl, VT, A, B, C); + + return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); } static SDValue combineZext(SDNode *N, SelectionDAG &DAG, @@ -33057,6 +33079,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget); case X86ISD::FMADD: case X86ISD::FMADD_RND: + case X86ISD::FMADDS1_RND: + case X86ISD::FMADDS3_RND: case ISD::FMA: return combineFMA(N, DAG, Subtarget); case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG); |