diff options
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 31 | ||||
| -rw-r--r-- | llvm/test/CodeGen/X86/fma-fneg-combine.ll | 39 |
2 files changed, 68 insertions, 2 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 1290d22a561..dc9bfa3296b 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35474,6 +35474,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { + // TODO: Handle FMSUB/FNMADD/FNMSUB as the starting opcode. SDLoc dl(N); EVT VT = N->getValueType(0); @@ -35574,6 +35575,32 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, return SDValue(); } +// Combine FMADDSUB(A, B, FNEG(C)) -> FMSUBADD(A, B, C) +static SDValue combineFMADDSUB(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + SDLoc dl(N); + EVT VT = N->getValueType(0); + + SDValue NegVal = isFNEG(N->getOperand(2).getNode()); + if (!NegVal) + return SDValue(); + + unsigned NewOpcode; + switch (N->getOpcode()) { + default: llvm_unreachable("Unexpected opcode!"); + case X86ISD::FMADDSUB: NewOpcode = X86ISD::FMSUBADD; break; + case X86ISD::FMADDSUB_RND: NewOpcode = X86ISD::FMSUBADD_RND; break; + case X86ISD::FMSUBADD: NewOpcode = X86ISD::FMADDSUB; break; + case X86ISD::FMSUBADD_RND: NewOpcode = X86ISD::FMADDSUB_RND; break; + } + + if (N->getNumOperands() == 4) + return DAG.getNode(NewOpcode, dl, VT, N->getOperand(0), N->getOperand(1), + NegVal, N->getOperand(3)); + return DAG.getNode(NewOpcode, dl, VT, N->getOperand(0), N->getOperand(1), + NegVal); +} + static SDValue combineZext(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -36868,6 +36895,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::FMADDS1: case X86ISD::FMADDS3: case ISD::FMA: return combineFMA(N, DAG, Subtarget); + case X86ISD::FMADDSUB_RND: + case X86ISD::FMSUBADD_RND: + case X86ISD::FMADDSUB: + case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, Subtarget); case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG); case X86ISD::TESTM: return combineTestM(N, DAG, Subtarget); diff --git a/llvm/test/CodeGen/X86/fma-fneg-combine.ll b/llvm/test/CodeGen/X86/fma-fneg-combine.ll index 6a31bc0f103..8247cb27978 100644 --- a/llvm/test/CodeGen/X86/fma-fneg-combine.ll +++ b/llvm/test/CodeGen/X86/fma-fneg-combine.ll @@ -97,7 +97,7 @@ define <8 x float> @test8(<8 x float> %a, <8 x float> %b, <8 x float> %c) { ; ; KNL-LABEL: test8: ; KNL: # BB#0: # %entry -; KNL-NEXT: vbroadcastss {{.*}}(%rip), %ymm3 +; KNL-NEXT: vbroadcastss {{.*#+}} ymm3 = [-0,-0,-0,-0,-0,-0,-0,-0] ; KNL-NEXT: vxorps %ymm3, %ymm2, %ymm2 ; KNL-NEXT: vfmsub213ps %ymm2, %ymm1, %ymm0 ; KNL-NEXT: retq @@ -147,7 +147,7 @@ define <4 x float> @test11(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 ze ; ; KNL-LABEL: test11: ; KNL: # BB#0: # %entry -; KNL-NEXT: vbroadcastss {{.*}}(%rip), %xmm0 +; KNL-NEXT: vbroadcastss {{.*#+}} xmm0 = [-0,-0,-0,-0] ; KNL-NEXT: vxorps %xmm0, %xmm2, %xmm0 ; KNL-NEXT: kmovw %edi, %k1 ; KNL-NEXT: vfmadd231ss %xmm1, %xmm1, %xmm0 {%k1} @@ -270,3 +270,38 @@ entry: ret <16 x float> %1 } +define <16 x float> @test16(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 %mask) { +; SKX-LABEL: test16: +; SKX: # BB#0: +; SKX-NEXT: kmovd %edi, %k1 +; SKX-NEXT: vfmsubadd132ps {rd-sae}, %zmm1, %zmm2, %zmm0 {%k1} +; SKX-NEXT: retq +; +; KNL-LABEL: test16: +; KNL: # BB#0: +; KNL-NEXT: kmovw %edi, %k1 +; KNL-NEXT: vfmsubadd132ps {rd-sae}, %zmm1, %zmm2, %zmm0 {%k1} +; KNL-NEXT: retq + %sub.i = fsub <16 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c + %res = call <16 x float> @llvm.x86.avx512.mask.vfmaddsub.ps.512(<16 x float> %a, <16 x float> %b, <16 x float> %sub.i, i16 %mask, i32 1) + ret <16 x float> %res +} +declare <16 x float> @llvm.x86.avx512.mask.vfmaddsub.ps.512(<16 x float>, <16 x float>, <16 x float>, i16, i32) + +define <8 x double> @test17(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8 %mask) { +; SKX-LABEL: test17: +; SKX: # BB#0: +; SKX-NEXT: kmovd %edi, %k1 +; SKX-NEXT: vfmsubadd132pd %zmm1, %zmm2, %zmm0 {%k1} +; SKX-NEXT: retq +; +; KNL-LABEL: test17: +; KNL: # BB#0: +; KNL-NEXT: kmovw %edi, %k1 +; KNL-NEXT: vfmsubadd132pd %zmm1, %zmm2, %zmm0 {%k1} +; KNL-NEXT: retq + %sub.i = fsub <8 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %c + %res = call <8 x double> @llvm.x86.avx512.mask.vfmaddsub.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %sub.i, i8 %mask, i32 4) + ret <8 x double> %res +} +declare <8 x double> @llvm.x86.avx512.mask.vfmaddsub.pd.512(<8 x double>, <8 x double>, <8 x double>, i8, i32) |

