diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 123 |
1 files changed, 64 insertions, 59 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c5e8102a89b..67e56cc33b8 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -30680,8 +30680,64 @@ static SDValue OptimizeConditionalInDecrement(SDNode *N, SelectionDAG &DAG) { DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp); } -static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +// Given a select, detect the following pattern: +// 1: %2 = zext <N x i8> %0 to <N x i32> +// 2: %3 = zext <N x i8> %1 to <N x i32> +// 3: %4 = sub nsw <N x i32> %2, %3 +// 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] +// 5: %6 = sub nsw <N x i32> zeroinitializer, %4 +// 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 +// This is useful as it is the input into a SAD pattern. +static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, + SDValue &Op1) { + // Check the condition of the select instruction is greater-than. + SDValue SetCC = Select->getOperand(0); + if (SetCC.getOpcode() != ISD::SETCC) + return false; + ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); + if (CC != ISD::SETGT) + return false; + + SDValue SelectOp1 = Select->getOperand(1); + SDValue SelectOp2 = Select->getOperand(2); + + // The second operand of the select should be the negation of the first + // operand, which is implemented as 0 - SelectOp1. + if (!(SelectOp2.getOpcode() == ISD::SUB && + ISD::isBuildVectorAllZeros(SelectOp2.getOperand(0).getNode()) && + SelectOp2.getOperand(1) == SelectOp1)) + return false; + + // The first operand of SetCC is the first operand of the select, which is the + // difference between the two input vectors. + if (SetCC.getOperand(0) != SelectOp1) + return false; + + // The second operand of the comparison can be either -1 or 0. + if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || + ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) + return false; + + // The first operand of the select is the difference between the two input + // vectors. + if (SelectOp1.getOpcode() != ISD::SUB) + return false; + + Op0 = SelectOp1.getOperand(0); + Op1 = SelectOp1.getOperand(1); + + // Check if the operands of the sub are zero-extended from vectors of i8. + if (Op0.getOpcode() != ISD::ZERO_EXTEND || + Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || + Op1.getOpcode() != ISD::ZERO_EXTEND || + Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8) + return false; + + return true; +} + +static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDLoc DL(N); EVT VT = N->getValueType(0); SDValue Op0 = N->getOperand(0); @@ -30701,21 +30757,8 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, if (VT.getSizeInBits() / 4 > RegSize) return SDValue(); - // Detect the following pattern: - // - // 1: %2 = zext <N x i8> %0 to <N x i32> - // 2: %3 = zext <N x i8> %1 to <N x i32> - // 3: %4 = sub nsw <N x i32> %2, %3 - // 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] - // 5: %6 = sub nsw <N x i32> zeroinitializer, %4 - // 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 - // 7: %8 = add nsw <N x i32> %7, %vec.phi - // - // The last instruction must be a reduction add. The instructions 3-6 forms an - // ABSDIFF pattern. - - // The two operands of reduction add are from PHI and a select-op as in line 7 - // above. + // We know N is a reduction add, which means one of its operands is a phi. + // To match SAD, we need the other operand to be a vector select. SDValue SelectOp, Phi; if (Op0.getOpcode() == ISD::VSELECT) { SelectOp = Op0; @@ -30726,50 +30769,12 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, } else return SDValue(); - // Check the condition of the select instruction is greater-than. - SDValue SetCC = SelectOp->getOperand(0); - if (SetCC.getOpcode() != ISD::SETCC) - return SDValue(); - ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); - if (CC != ISD::SETGT) - return SDValue(); - - Op0 = SelectOp->getOperand(1); - Op1 = SelectOp->getOperand(2); - - // The second operand of SelectOp Op1 is the negation of the first operand - // Op0, which is implemented as 0 - Op0. - if (!(Op1.getOpcode() == ISD::SUB && - ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) && - Op1.getOperand(1) == Op0)) - return SDValue(); - - // The first operand of SetCC is the first operand of SelectOp, which is the - // difference between two input vectors. - if (SetCC.getOperand(0) != Op0) - return SDValue(); - - // The second operand of > comparison can be either -1 or 0. - if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || - ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) - return SDValue(); - - // The first operand of SelectOp is the difference between two input vectors. - if (Op0.getOpcode() != ISD::SUB) - return SDValue(); - - Op1 = Op0.getOperand(1); - Op0 = Op0.getOperand(0); - - // Check if the operands of the diff are zero-extended from vectors of i8. - if (Op0.getOpcode() != ISD::ZERO_EXTEND || - Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || - Op1.getOpcode() != ISD::ZERO_EXTEND || - Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8) + // Check whether we have an abs-diff pattern feeding into the select. + if(!detectZextAbsDiff(SelectOp, Op0, Op1)) return SDValue(); // SAD pattern detected. Now build a SAD instruction and an addition for - // reduction. Note that the number of elments of the result of SAD is less + // reduction. Note that the number of elements of the result of SAD is less // than the number of elements of its input. Therefore, we could only update // part of elements in the reduction vector. @@ -30819,7 +30824,7 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; if (Flags->hasVectorReduction()) { - if (SDValue Sad = detectSADPattern(N, DAG, Subtarget)) + if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget)) return Sad; } EVT VT = N->getValueType(0); |