diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 141 |
1 files changed, 63 insertions, 78 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 001862bc2f9..52ce9ec18e3 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -43151,8 +43151,17 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); + SDValue MulOp = N->getOperand(0); + SDValue OtherOp = N->getOperand(1); + + if (MulOp.getOpcode() != ISD::MUL) + std::swap(MulOp, OtherOp); + if (MulOp.getOpcode() != ISD::MUL) + return SDValue(); + + ShrinkMode Mode; + if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16) + return SDValue(); EVT VT = N->getValueType(0); @@ -43161,49 +43170,33 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, if (!VT.isVector() || VT.getVectorNumElements() < 8) return SDValue(); - if (Op0.getOpcode() != ISD::MUL) - std::swap(Op0, Op1); - if (Op0.getOpcode() != ISD::MUL) - return SDValue(); - - ShrinkMode Mode; - if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16) - return SDValue(); - SDLoc DL(N); EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, VT.getVectorNumElements() / 2); + // Shrink the operands of mul. + SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0)); + SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1)); + // Madd vector size is half of the original vector size auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, ArrayRef<SDValue> Ops) { MVT OpVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32); return DAG.getNode(X86ISD::VPMADDWD, DL, OpVT, Ops); }; - - auto BuildPMADDWD = [&](SDValue Mul) { - // Shrink the operands of mul. - SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0)); - SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1)); - - SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, - PMADDWDBuilder); - // Fill the rest of the output with 0 - return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, - DAG.getConstant(0, DL, MAddVT)); - }; - - Op0 = BuildPMADDWD(Op0); - - // It's possible that Op1 is also a mul we can reduce. - if (Op1.getOpcode() == ISD::MUL && - canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) { - Op1 = BuildPMADDWD(Op1); - } - - return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); + SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 }, + PMADDWDBuilder); + // Fill the rest of the output with 0 + SDValue Zero = DAG.getConstant(0, DL, Madd.getSimpleValueType()); + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero); + + // Preserve the reduction flag on the ADD. We may need to revisit for the + // other operand. + SDNodeFlags Flags; + Flags.setVectorReduction(true); + return DAG.getNode(ISD::ADD, DL, VT, Concat, OtherOp, Flags); } static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, @@ -43213,8 +43206,6 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, SDLoc DL(N); EVT VT = N->getValueType(0); - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); // TODO: There's nothing special about i32, any integer type above i16 should // work just as well. @@ -43234,55 +43225,49 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, if (VT.getSizeInBits() / 4 > RegSize) return SDValue(); - // We know N is a reduction add, which means one of its operands is a phi. - // To match SAD, we need the other operand to be a ABS. - if (Op0.getOpcode() != ISD::ABS) - std::swap(Op0, Op1); - if (Op0.getOpcode() != ISD::ABS) - return SDValue(); - - auto BuildPSADBW = [&](SDValue Op0, SDValue Op1) { - // SAD pattern detected. Now build a SAD instruction and an addition for - // reduction. Note that the number of elements of the result of SAD is less - // than the number of elements of its input. Therefore, we could only update - // part of elements in the reduction vector. - SDValue Sad = createPSADBW(DAG, Op0, Op1, DL, Subtarget); - - // The output of PSADBW is a vector of i64. - // We need to turn the vector of i64 into a vector of i32. - // If the reduction vector is at least as wide as the psadbw result, just - // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero - // anyway. - MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32); - if (VT.getSizeInBits() >= ResVT.getSizeInBits()) - Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); - else - Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); - - if (VT.getSizeInBits() > ResVT.getSizeInBits()) { - // Fill the upper elements with zero to match the add width. - SDValue Zero = DAG.getConstant(0, DL, VT); - Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, - DAG.getIntPtrConstant(0, DL)); - } - - return Sad; - }; + // We know N is a reduction add. To match SAD, we need one of the operands to + // be an ABS. + SDValue AbsOp = N->getOperand(0); + SDValue OtherOp = N->getOperand(1); + if (AbsOp.getOpcode() != ISD::ABS) + std::swap(AbsOp, OtherOp); + if (AbsOp.getOpcode() != ISD::ABS) + return SDValue(); // Check whether we have an abs-diff pattern feeding into the select. SDValue SadOp0, SadOp1; - if (!detectZextAbsDiff(Op0, SadOp0, SadOp1)) - return SDValue(); - - Op0 = BuildPSADBW(SadOp0, SadOp1); + if(!detectZextAbsDiff(AbsOp, SadOp0, SadOp1)) + return SDValue(); + + // SAD pattern detected. Now build a SAD instruction and an addition for + // reduction. Note that the number of elements of the result of SAD is less + // than the number of elements of its input. Therefore, we could only update + // part of elements in the reduction vector. + SDValue Sad = createPSADBW(DAG, SadOp0, SadOp1, DL, Subtarget); + + // The output of PSADBW is a vector of i64. + // We need to turn the vector of i64 into a vector of i32. + // If the reduction vector is at least as wide as the psadbw result, just + // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero + // anyway. + MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32); + if (VT.getSizeInBits() >= ResVT.getSizeInBits()) + Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad); + else + Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); - // It's possible we have a sad on the other side too. - if (Op1.getOpcode() == ISD::ABS && - detectZextAbsDiff(Op1, SadOp0, SadOp1)) { - Op1 = BuildPSADBW(SadOp0, SadOp1); + if (VT.getSizeInBits() > ResVT.getSizeInBits()) { + // Fill the upper elements with zero to match the add width. + SDValue Zero = DAG.getConstant(0, DL, VT); + Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, + DAG.getIntPtrConstant(0, DL)); } - return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1); + // Preserve the reduction flag on the ADD. We may need to revisit for the + // other operand. + SDNodeFlags Flags; + Flags.setVectorReduction(true); + return DAG.getNode(ISD::ADD, DL, VT, Sad, OtherOp, Flags); } /// Convert vector increment or decrement to sub/add with an all-ones constant: |