diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 141 | ||||
-rw-r--r-- | llvm/test/CodeGen/X86/madd.ll | 4 | ||||
-rw-r--r-- | llvm/test/CodeGen/X86/sad.ll | 30 |
3 files changed, 80 insertions, 95 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 001862bc2f9..52ce9ec18e3 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -43151,8 +43151,17 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); + SDValue MulOp = N->getOperand(0); + SDValue OtherOp = N->getOperand(1); + + if (MulOp.getOpcode() != ISD::MUL) + std::swap(MulOp, OtherOp); + if (MulOp.getOpcode() != ISD::MUL) + return SDValue(); + + ShrinkMode Mode; + if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) + return SDValue(); EVT VT = N->getValueType(0); @@ -43161,49 +43170,33 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!VT.isVector() || VT.getVectorNumElements() < 8) return SDValue(); - if (Op0.getOpcode() != ISD::MUL) - std::swap(Op0, Op1); - if (Op0.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); - SDLoc DL(N); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, VT.getVectorNumElements() / 2); + // Shrink the operands of mul. + SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0)); + SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1)); + // Madd vector size is half of the original vector size auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef<SDValue> Ops) { MVT OpVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); return DAG.getNode(X86ISD::VPMADDWD, DL, OpVT, Ops); }; - - auto BuildPMADDWD = [&](SDValue Mul) { - // Shrink the operands of mul. - SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0)); - SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1)); - - SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, - PMADDWDBuilder); - // Fill the rest of the output with 0 - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, - DAG.getConstant(0, DL, MAddVT)); - }; - - Op0 = BuildPMADDWD(Op0); - - // It's possible that Op1 is also a mul we can reduce. - if (Op1.getOpcode() == ISD::MUL && - canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) { - Op1 = BuildPMADDWD(Op1); - } - - return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); + SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, + PMADDWDBuilder); + // Fill the rest of the output with 0 + SDValue Zero = DAG.getConstant(0, DL, Madd.getSimpleValueType()); + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero); + + // Preserve the reduction flag on the ADD. We may need to revisit for the + // other operand. + SDNodeFlags Flags; + Flags.setVectorReduction(true); + return DAG.getNode(ISD::ADD, DL, VT, Concat, OtherOp, Flags); } static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, @@ -43213,8 +43206,6 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, SDLoc DL(N); EVT VT = N->getValueType(0); - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); // TODO: There's nothing special about i32, any integer type above i16 should // work just as well. @@ -43234,55 +43225,49 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, if (VT.getSizeInBits() / 4 > RegSize) return SDValue(); - // We know N is a reduction add, which means one of its operands is a phi. - // To match SAD, we need the other operand to be a ABS. - if (Op0.getOpcode() != ISD::ABS) - std::swap(Op0, Op1); - if (Op0.getOpcode() != ISD::ABS) - return SDValue(); - - auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) { - // SAD pattern detected. Now build a SAD instruction and an addition for - // reduction. Note that the number of elements of the result of SAD is less - // than the number of elements of its input. Therefore, we could only update - // part of elements in the reduction vector. - SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget); - - // The output of PSADBW is a vector of i64. - // We need to turn the vector of i64 into a vector of i32. - // If the reduction vector is at least as wide as the psadbw result, just - // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero - // anyway. - MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32); - if (VT.getSizeInBits() >= ResVT.getSizeInBits()) - Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); - else - Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); - - if (VT.getSizeInBits() > ResVT.getSizeInBits()) { - // Fill the upper elements with zero to match the add width. - SDValue Zero = DAG.getConstant(0, DL, VT); - Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, - DAG.getIntPtrConstant(0, DL)); - } - - return Sad; - }; + // We know N is a reduction add. To match SAD, we need one of the operands to + // be an ABS. + SDValue AbsOp = N->getOperand(0); + SDValue OtherOp = N->getOperand(1); + if (AbsOp.getOpcode() != ISD::ABS) + std::swap(AbsOp, OtherOp); + if (AbsOp.getOpcode() != ISD::ABS) + return SDValue(); // Check whether we have an abs-diff pattern feeding into the select. SDValue SadOp0, SadOp1; - if (!detectZextAbsDiff(Op0, SadOp0, SadOp1)) - return SDValue(); - - Op0 = BuildPSADBW(SadOp0, SadOp1); + if(!detectZextAbsDiff(AbsOp, SadOp0, SadOp1)) + return SDValue(); + + // SAD pattern detected. Now build a SAD instruction and an addition for + // reduction. Note that the number of elements of the result of SAD is less + // than the number of elements of its input. Therefore, we could only update + // part of elements in the reduction vector. + SDValue Sad = createPSADBW(DAG, SadOp0, SadOp1, DL, Subtarget); + + // The output of PSADBW is a vector of i64. + // We need to turn the vector of i64 into a vector of i32. + // If the reduction vector is at least as wide as the psadbw result, just + // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero + // anyway. + MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32); + if (VT.getSizeInBits() >= ResVT.getSizeInBits()) + Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); + else + Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); - // It's possible we have a sad on the other side too. - if (Op1.getOpcode() == ISD::ABS && - detectZextAbsDiff(Op1, SadOp0, SadOp1)) { - Op1 = BuildPSADBW(SadOp0, SadOp1); + if (VT.getSizeInBits() > ResVT.getSizeInBits()) { + // Fill the upper elements with zero to match the add width. + SDValue Zero = DAG.getConstant(0, DL, VT); + Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, + DAG.getIntPtrConstant(0, DL)); } - return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); + // Preserve the reduction flag on the ADD. We may need to revisit for the + // other operand. + SDNodeFlags Flags; + Flags.setVectorReduction(true); + return DAG.getNode(ISD::ADD, DL, VT, Sad, OtherOp, Flags); } /// Convert vector increment or decrement to sub/add with an all-ones constant: diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll index 41f841a8ed9..ecf43f5f65d 100644 --- a/llvm/test/CodeGen/X86/madd.ll +++ b/llvm/test/CodeGen/X86/madd.ll @@ -2677,9 +2677,9 @@ define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>* ; AVX: # %bb.0: ; AVX-NEXT: vmovdqu (%rdi), %xmm0 ; AVX-NEXT: vmovdqu (%rdx), %xmm1 -; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1 -; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 +; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0 ; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] ; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0 ; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] diff --git a/llvm/test/CodeGen/X86/sad.ll b/llvm/test/CodeGen/X86/sad.ll index 0bbae5c85e9..5b1a0f8cead 100644 --- a/llvm/test/CodeGen/X86/sad.ll +++ b/llvm/test/CodeGen/X86/sad.ll @@ -1403,18 +1403,18 @@ define i32 @sad_unroll_nonzero_initial(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x ; SSE2-NEXT: movdqu (%rdi), %xmm0 ; SSE2-NEXT: movdqu (%rsi), %xmm1 ; SSE2-NEXT: psadbw %xmm0, %xmm1 -; SSE2-NEXT: movdqu (%rdx), %xmm0 -; SSE2-NEXT: movdqu (%rcx), %xmm2 -; SSE2-NEXT: psadbw %xmm0, %xmm2 ; SSE2-NEXT: movl $1, %eax ; SSE2-NEXT: movd %eax, %xmm0 -; SSE2-NEXT: paddd %xmm2, %xmm0 -; SSE2-NEXT: paddd %xmm1, %xmm0 -; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] +; SSE2-NEXT: movdqu (%rdx), %xmm2 +; SSE2-NEXT: movdqu (%rcx), %xmm3 +; SSE2-NEXT: psadbw %xmm2, %xmm3 +; SSE2-NEXT: paddd %xmm0, %xmm3 +; SSE2-NEXT: paddd %xmm1, %xmm3 +; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm3[2,3,0,1] +; SSE2-NEXT: paddd %xmm3, %xmm0 +; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] ; SSE2-NEXT: paddd %xmm0, %xmm1 -; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[1,1,2,3] -; SSE2-NEXT: paddd %xmm1, %xmm0 -; SSE2-NEXT: movd %xmm0, %eax +; SSE2-NEXT: movd %xmm1, %eax ; SSE2-NEXT: retq ; ; AVX1-LABEL: sad_unroll_nonzero_initial: @@ -1442,7 +1442,7 @@ define i32 @sad_unroll_nonzero_initial(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x ; AVX2-NEXT: vmovd %eax, %xmm1 ; AVX2-NEXT: vmovdqu (%rdx), %xmm2 ; AVX2-NEXT: vpsadbw (%rcx), %xmm2, %xmm2 -; AVX2-NEXT: vpaddd %xmm1, %xmm2, %xmm1 +; AVX2-NEXT: vpaddd %xmm2, %xmm1, %xmm1 ; AVX2-NEXT: vpaddd %xmm1, %xmm0, %xmm0 ; AVX2-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] ; AVX2-NEXT: vpaddd %xmm1, %xmm0, %xmm0 @@ -1455,10 +1455,10 @@ define i32 @sad_unroll_nonzero_initial(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x ; AVX512: # %bb.0: # %bb ; AVX512-NEXT: vmovdqu (%rdi), %xmm0 ; AVX512-NEXT: vpsadbw (%rsi), %xmm0, %xmm0 -; AVX512-NEXT: vmovdqu (%rdx), %xmm1 -; AVX512-NEXT: vpsadbw (%rcx), %xmm1, %xmm1 ; AVX512-NEXT: movl $1, %eax -; AVX512-NEXT: vmovd %eax, %xmm2 +; AVX512-NEXT: vmovd %eax, %xmm1 +; AVX512-NEXT: vmovdqu (%rdx), %xmm2 +; AVX512-NEXT: vpsadbw (%rcx), %xmm2, %xmm2 ; AVX512-NEXT: vpaddd %zmm2, %zmm1, %zmm1 ; AVX512-NEXT: vpaddd %zmm1, %zmm0, %zmm0 ; AVX512-NEXT: vextracti64x4 $1, %zmm0, %ymm1 @@ -1526,9 +1526,9 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* % ; AVX: # %bb.0: # %bb ; AVX-NEXT: vmovdqu (%rdi), %xmm0 ; AVX-NEXT: vmovdqu (%rdx), %xmm1 -; AVX-NEXT: vpsadbw (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vpsadbw (%rcx), %xmm1, %xmm1 -; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVX-NEXT: vpsadbw (%rsi), %xmm0, %xmm0 +; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0 ; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] ; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0 ; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] |