diff options
| author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2018-10-12 14:18:47 +0000 |
|---|---|---|
| committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2018-10-12 14:18:47 +0000 |
| commit | 78b5a3c3ef120e51e31a592ec98b2f0558f2f284 (patch) | |
| tree | c2bec427e9a5038ceef5680db580c4620c636110 /llvm/lib/Target/X86 | |
| parent | 9552dd187aadd92aeacda13ad4294be12ebe85ab (diff) | |
| download | bcm5719-llvm-78b5a3c3ef120e51e31a592ec98b2f0558f2f284.tar.gz bcm5719-llvm-78b5a3c3ef120e51e31a592ec98b2f0558f2f284.zip | |
[X86][SSE] LowerVectorCTPOP - pull out repeated byte sum stage.
Pull out the repeated byte-sum stage used for popcounts of vector elements wider than 8 bits.
This lets us simplify the LUT/BITMATH popcnt code to always assume vXi8 vectors, and it also improves avx512bitalg codegen, which only has access to vpopcntb/vpopcntw.
llvm-svn: 344348
Diffstat (limited to 'llvm/lib/Target/X86')
| -rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 81 |
1 files changed, 30 insertions, 51 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 15bd238833d..d2971d0f861 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -25023,7 +25023,8 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); MVT EltVT = VT.getVectorElementType(); - unsigned VecSize = VT.getSizeInBits(); + int NumElts = VT.getVectorNumElements(); + assert(EltVT == MVT::i8 && "Only vXi8 vector CTPOP lowering supported."); // Implement a lookup table in register by using an algorithm based on: // http://wm.ite.pl/articles/sse-popcount.html @@ -25035,56 +25036,37 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL, // masked out higher ones) for each byte. PSHUFB is used separately with both // to index the in-register table. Next, both are added and the result is a // i8 vector where each element contains the pop count for input byte. - // - // To obtain the pop count for elements != i8, we follow up with the same - // approach and use additional tricks as described below. 
- // const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2, /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3, /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3, /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4}; - int NumByteElts = VecSize / 8; - MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts); - SDValue In = DAG.getBitcast(ByteVecVT, Op); SmallVector<SDValue, 64> LUTVec; - for (int i = 0; i < NumByteElts; ++i) + for (int i = 0; i < NumElts; ++i) LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8)); - SDValue InRegLUT = DAG.getBuildVector(ByteVecVT, DL, LUTVec); - SDValue M0F = DAG.getConstant(0x0F, DL, ByteVecVT); + SDValue InRegLUT = DAG.getBuildVector(VT, DL, LUTVec); + SDValue M0F = DAG.getConstant(0x0F, DL, VT); // High nibbles - SDValue FourV = DAG.getConstant(4, DL, ByteVecVT); - SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV); + SDValue FourV = DAG.getConstant(4, DL, VT); + SDValue HiNibbles = DAG.getNode(ISD::SRL, DL, VT, Op, FourV); // Low nibbles - SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F); + SDValue LoNibbles = DAG.getNode(ISD::AND, DL, VT, Op, M0F); // The input vector is used as the shuffle mask that index elements into the // LUT. After counting low and high nibbles, add the vector to obtain the // final pop count per i8 element. 
- SDValue HighPopCnt = - DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles); - SDValue LowPopCnt = - DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles); - SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt); - - if (EltVT == MVT::i8) - return PopCnt; - - return LowerHorizontalByteSum(PopCnt, VT, Subtarget, DAG); + SDValue HiPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, HiNibbles); + SDValue LoPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, LoNibbles); + return DAG.getNode(ISD::ADD, DL, VT, HiPopCnt, LoPopCnt); } static SDValue LowerVectorCTPOPBitmath(SDValue Op, const SDLoc &DL, const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); - assert(VT.is128BitVector() && - "Only 128-bit vector bitmath lowering supported."); - - int VecSize = VT.getSizeInBits(); - MVT EltVT = VT.getVectorElementType(); - int Len = EltVT.getSizeInBits(); + assert(VT == MVT::v16i8 && "Only v16i8 vector CTPOP lowering supported."); // This is the vectorized version of the "best" algorithm from // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel @@ -25108,36 +25090,27 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, const SDLoc &DL, // x86, so set the SRL type to have elements at least i16 wide. This is // correct because all of our SRLs are followed immediately by a mask anyways // that handles any bits that sneak into the high bits of the byte elements. - MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16); - + MVT SrlVT = MVT::v8i16; SDValue V = Op; // v = v - ((v >> 1) & 0x55555555...) SDValue Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 1)); - SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55))); + SDValue And = GetMask(Srl, APInt(8, 0x55)); V = DAG.getNode(ISD::SUB, DL, VT, V, And); // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...) 
- SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33))); + SDValue AndLHS = GetMask(V, APInt(8, 0x33)); Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 2)); - SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33))); + SDValue AndRHS = GetMask(Srl, APInt(8, 0x33)); V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS); // v = (v + (v >> 4)) & 0x0F0F0F0F... Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 4)); SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl); - V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F))); + V = GetMask(Add, APInt(8, 0x0F)); - // At this point, V contains the byte-wise population count, and we are - // merely doing a horizontal sum if necessary to get the wider element - // counts. - if (EltVT == MVT::i8) - return V; - - return LowerHorizontalByteSum( - DAG.getBitcast(MVT::getVectorVT(MVT::i8, VecSize / 8), V), VT, Subtarget, - DAG); + return V; } // Please ensure that any codegen change from LowerVectorCTPOP is reflected in @@ -25163,12 +25136,6 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, } } - if (!Subtarget.hasSSSE3()) { - // We can't use the fast LUT approach, so fall back on vectorized bitmath. - assert(VT.is128BitVector() && "Only 128-bit vectors supported in SSE!"); - return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG); - } - // Decompose 256-bit ops into smaller 128-bit ops. if (VT.is256BitVector() && !Subtarget.hasInt256()) return Lower256IntUnary(Op, DAG); @@ -25177,6 +25144,18 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, if (VT.is512BitVector() && !Subtarget.hasBWI()) return Lower512IntUnary(Op, DAG); + // For element types greater than i8, do vXi8 pop counts and a bytesum. 
+ if (VT.getScalarType() != MVT::i8) { + MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8); + SDValue ByteOp = DAG.getBitcast(ByteVT, Op0); + SDValue PopCnt8 = DAG.getNode(ISD::CTPOP, DL, ByteVT, ByteOp); + return LowerHorizontalByteSum(PopCnt8, VT, Subtarget, DAG); + } + + // We can't use the fast LUT approach, so fall back on vectorized bitmath. + if (!Subtarget.hasSSSE3()) + return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG); + return LowerVectorCTPOPInRegLUT(Op0, DL, Subtarget, DAG); } |

