diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 54 |
1 files changed, 41 insertions, 13 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 50cd8ab1dbc..e3fee6fbb63 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -29144,12 +29144,18 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, if (SetCC.getOpcode() != ISD::SETCC) return false; ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); - if (CC != ISD::SETGT) + if (CC != ISD::SETGT && CC != ISD::SETLT) return false; SDValue SelectOp1 = Select->getOperand(1); SDValue SelectOp2 = Select->getOperand(2); + // The following instructions assume SelectOp1 is the subtraction operand + // and SelectOp2 is the negation operand. + // In the case of SETLT this is the other way around. + if (CC == ISD::SETLT) + std::swap(SelectOp1, SelectOp2); + // The second operand of the select should be the negation of the first // operand, which is implemented as 0 - SelectOp1. if (!(SelectOp2.getOpcode() == ISD::SUB && @@ -29162,8 +29168,17 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, if (SetCC.getOperand(0) != SelectOp1) return false; - // The second operand of the comparison can be either -1 or 0. - if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || + // In SetLT case, The second operand of the comparison can be either 1 or 0. + APInt SplatVal; + if ((CC == ISD::SETLT) && + !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) && + SplatVal == 1) || + (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode())))) + return false; + + // In SetGT case, The second operand of the comparison can be either -1 or 0. + if ((CC == ISD::SETGT) && + !(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) return false; @@ -29292,11 +29307,9 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - // Verify the type we're extracting from is appropriate - // TODO: There's nothing special about i32, any integer type above i16 should - // work just as well. + // Verify the type we're extracting from is any integer type above i16. EVT VT = Extract->getOperand(0).getValueType(); - if (!VT.isSimple() || !(VT.getVectorElementType() == MVT::i32)) + if (!VT.isSimple() || !(VT.getVectorElementType().getSizeInBits() > 16)) return SDValue(); unsigned RegSize = 128; @@ -29305,15 +29318,28 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, else if (Subtarget.hasAVX2()) RegSize = 256; - // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. + // We handle upto v16i* for SSE2 / v32i* for AVX2 / v64i* for AVX512. // TODO: We should be able to handle larger vectors by splitting them before // feeding them into several SADs, and then reducing over those. - if (VT.getSizeInBits() / 4 > RegSize) + if (RegSize / VT.getVectorNumElements() < 8) return SDValue(); // Match shuffle + add pyramid. SDValue Root = matchBinOpReduction(Extract, ISD::ADD); + // The operand is expected to be zero extended from i8 + // (verified in detectZextAbsDiff). + // In order to convert to i64 and above, additional any/zero/sign + // extend is expected. + // The zero extend from 32 bit has no mathematical effect on the result. + // Also the sign extend is basically zero extend + // (extends the sign bit which is zero). + // So it is correct to skip the sign/zero extend instruction. + if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND || + Root.getOpcode() == ISD::ZERO_EXTEND || + Root.getOpcode() == ISD::ANY_EXTEND)) + Root = Root.getOperand(0); + // If there was a match, we want Root to be a select that is the root of an // abs-diff pattern. if (!Root || (Root.getOpcode() != ISD::VSELECT)) @@ -29324,7 +29350,7 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, if (!detectZextAbsDiff(Root, Zext0, Zext1)) return SDValue(); - // Create the SAD instruction + // Create the SAD instruction. SDLoc DL(Extract); SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL); @@ -29346,10 +29372,12 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, } } - // Return the lowest i32. - MVT ResVT = MVT::getVectorVT(MVT::i32, SadVT.getSizeInBits() / 32); + MVT Type = Extract->getSimpleValueType(0); + unsigned TypeSizeInBits = Type.getSizeInBits(); + // Return the lowest TypeSizeInBits bits. + MVT ResVT = MVT::getVectorVT(Type, SadVT.getSizeInBits() / TypeSizeInBits); SAD = DAG.getNode(ISD::BITCAST, DL, ResVT, SAD); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, SAD, + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Type, SAD, Extract->getOperand(1)); } |