diff options
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 34 | ||||
| -rw-r--r-- | llvm/test/CodeGen/X86/madd.ll | 28 |
2 files changed, 36 insertions, 26 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 52ce9ec18e3..17481f7fb26 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -43151,18 +43151,6 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - SDValue MulOp = N->getOperand(0); - SDValue OtherOp = N->getOperand(1); - - if (MulOp.getOpcode() != ISD::MUL) - std::swap(MulOp, OtherOp); - if (MulOp.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); - EVT VT = N->getValueType(0); // If the vector size is less than 128, or greater than the supported RegSize, @@ -43170,6 +43158,28 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!VT.isVector() || VT.getVectorNumElements() < 8) return SDValue(); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + + auto UsePMADDWD = [&](SDValue Op) { + ShrinkMode Mode; + return Op.getOpcode() == ISD::MUL && + canReduceVMulWidth(Op.getNode(), DAG, Mode) && Mode != MULU16 && + (!Subtarget.hasSSE41() || + (Op->isOnlyUserOf(Op.getOperand(0).getNode()) && + Op->isOnlyUserOf(Op.getOperand(1).getNode()))); + }; + + SDValue MulOp, OtherOp; + if (UsePMADDWD(Op0)) { + MulOp = Op0; + OtherOp = Op1; + } else if (UsePMADDWD(Op1)) { + MulOp = Op1; + OtherOp = Op0; + } else + return SDValue(); + SDLoc DL(N); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll index 2f955f7a299..8c1f90461dd 100644 --- a/llvm/test/CodeGen/X86/madd.ll +++ b/llvm/test/CodeGen/X86/madd.ll @@ -2745,20 +2745,21 @@ define i64 @sum_and_sum_of_squares(i8* %a, i32 %n) { ; AVX1-NEXT: movl %esi, %eax ; AVX1-NEXT: vpxor %xmm0, %xmm0, %xmm0 ; AVX1-NEXT: vpxor %xmm1, %xmm1, %xmm1 -; AVX1-NEXT: vpxor %xmm2, %xmm2, %xmm2 ; AVX1-NEXT: .p2align 4, 0x90 ; AVX1-NEXT: .LBB32_1: # %vector.body ; AVX1-NEXT: # =>This Inner Loop Header: Depth=1 -; AVX1-NEXT: vpmovzxbw {{.*#+}} xmm3 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero -; AVX1-NEXT: vpunpckhwd {{.*#+}} xmm4 = xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7] -; AVX1-NEXT: vpmovzxwd {{.*#+}} xmm5 = xmm3[0],zero,xmm3[1],zero,xmm3[2],zero,xmm3[3],zero -; AVX1-NEXT: vextractf128 $1, %ymm2, %xmm6 -; AVX1-NEXT: vpaddd %xmm6, %xmm4, %xmm4 -; AVX1-NEXT: vpaddd %xmm2, %xmm5, %xmm2 -; AVX1-NEXT: vinsertf128 $1, %xmm4, %ymm2, %ymm2 +; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero +; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero +; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm4 +; AVX1-NEXT: vpaddd %xmm4, %xmm3, %xmm4 +; AVX1-NEXT: vpaddd %xmm1, %xmm2, %xmm1 +; AVX1-NEXT: vinsertf128 $1, %xmm4, %ymm1, %ymm1 +; AVX1-NEXT: vpmaddwd %xmm2, %xmm2, %xmm2 ; AVX1-NEXT: vpmaddwd %xmm3, %xmm3, %xmm3 -; AVX1-NEXT: vpaddd %xmm0, %xmm3, %xmm3 -; AVX1-NEXT: vblendps {{.*#+}} ymm0 = ymm3[0,1,2,3],ymm0[4,5,6,7] +; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm4 +; AVX1-NEXT: vpaddd %xmm4, %xmm3, %xmm3 +; AVX1-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm0, %ymm0 ; AVX1-NEXT: addq $8, %rdi ; AVX1-NEXT: addq $-8, %rax ; AVX1-NEXT: jne .LBB32_1 @@ -2781,10 +2782,9 @@ define i64 @sum_and_sum_of_squares(i8* %a, i32 %n) { ; AVX256-NEXT: .p2align 4, 0x90 ; AVX256-NEXT: .LBB32_1: # %vector.body ; AVX256-NEXT: # =>This Inner Loop Header: Depth=1 -; AVX256-NEXT: vpmovzxbw {{.*#+}} xmm2 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero -; AVX256-NEXT: vpmovzxwd {{.*#+}} ymm3 = xmm2[0],zero,xmm2[1],zero,xmm2[2],zero,xmm2[3],zero,xmm2[4],zero,xmm2[5],zero,xmm2[6],zero,xmm2[7],zero -; AVX256-NEXT: vpaddd %ymm1, %ymm3, %ymm1 -; AVX256-NEXT: vpmaddwd %xmm2, %xmm2, %xmm2 +; AVX256-NEXT: vpmovzxbd {{.*#+}} ymm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero +; AVX256-NEXT: vpaddd %ymm1, %ymm2, %ymm1 +; AVX256-NEXT: vpmaddwd %ymm2, %ymm2, %ymm2 ; AVX256-NEXT: vpaddd %ymm0, %ymm2, %ymm0 ; AVX256-NEXT: addq $8, %rdi ; AVX256-NEXT: addq $-8, %rax |

