summary | refs | log | tree | commit | diff | stats
path: root/llvm/lib/Target/X86
diff options
context:
space:
mode:
author: Simon Pilgrim <llvm-dev@redking.me.uk> 2018-10-12 14:18:47 +0000
committer: Simon Pilgrim <llvm-dev@redking.me.uk> 2018-10-12 14:18:47 +0000
commit: 78b5a3c3ef120e51e31a592ec98b2f0558f2f284 (patch)
tree: c2bec427e9a5038ceef5680db580c4620c636110 /llvm/lib/Target/X86
parent: 9552dd187aadd92aeacda13ad4294be12ebe85ab (diff)
download: bcm5719-llvm-78b5a3c3ef120e51e31a592ec98b2f0558f2f284.tar.gz
bcm5719-llvm-78b5a3c3ef120e51e31a592ec98b2f0558f2f284.zip
[X86][SSE] LowerVectorCTPOP - pull out repeated byte sum stage.
Pull out the repeated byte sum stage for popcount of vector elements wider than 8 bits. This allows us to simplify the LUT/BITMATH popcnt code to always assume vXi8 vectors, and also improves avx512bitalg codegen, which only has access to vpopcntb/vpopcntw.

llvm-svn: 344348
Diffstat (limited to 'llvm/lib/Target/X86')
-rw-r--r-- llvm/lib/Target/X86/X86ISelLowering.cpp 81
1 files changed, 30 insertions, 51 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 15bd238833d..d2971d0f861 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -25023,7 +25023,8 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
MVT EltVT = VT.getVectorElementType();
- unsigned VecSize = VT.getSizeInBits();
+ int NumElts = VT.getVectorNumElements();
+ assert(EltVT == MVT::i8 && "Only vXi8 vector CTPOP lowering supported.");
// Implement a lookup table in register by using an algorithm based on:
// http://wm.ite.pl/articles/sse-popcount.html
@@ -25035,56 +25036,37 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL,
// masked out higher ones) for each byte. PSHUFB is used separately with both
// to index the in-register table. Next, both are added and the result is a
// i8 vector where each element contains the pop count for input byte.
- //
- // To obtain the pop count for elements != i8, we follow up with the same
- // approach and use additional tricks as described below.
- //
const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
/* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
/* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
/* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4};
- int NumByteElts = VecSize / 8;
- MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts);
- SDValue In = DAG.getBitcast(ByteVecVT, Op);
SmallVector<SDValue, 64> LUTVec;
- for (int i = 0; i < NumByteElts; ++i)
+ for (int i = 0; i < NumElts; ++i)
LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8));
- SDValue InRegLUT = DAG.getBuildVector(ByteVecVT, DL, LUTVec);
- SDValue M0F = DAG.getConstant(0x0F, DL, ByteVecVT);
+ SDValue InRegLUT = DAG.getBuildVector(VT, DL, LUTVec);
+ SDValue M0F = DAG.getConstant(0x0F, DL, VT);
// High nibbles
- SDValue FourV = DAG.getConstant(4, DL, ByteVecVT);
- SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV);
+ SDValue FourV = DAG.getConstant(4, DL, VT);
+ SDValue HiNibbles = DAG.getNode(ISD::SRL, DL, VT, Op, FourV);
// Low nibbles
- SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F);
+ SDValue LoNibbles = DAG.getNode(ISD::AND, DL, VT, Op, M0F);
// The input vector is used as the shuffle mask that index elements into the
// LUT. After counting low and high nibbles, add the vector to obtain the
// final pop count per i8 element.
- SDValue HighPopCnt =
- DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles);
- SDValue LowPopCnt =
- DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles);
- SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt);
-
- if (EltVT == MVT::i8)
- return PopCnt;
-
- return LowerHorizontalByteSum(PopCnt, VT, Subtarget, DAG);
+ SDValue HiPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, HiNibbles);
+ SDValue LoPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, LoNibbles);
+ return DAG.getNode(ISD::ADD, DL, VT, HiPopCnt, LoPopCnt);
}
static SDValue LowerVectorCTPOPBitmath(SDValue Op, const SDLoc &DL,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
- assert(VT.is128BitVector() &&
- "Only 128-bit vector bitmath lowering supported.");
-
- int VecSize = VT.getSizeInBits();
- MVT EltVT = VT.getVectorElementType();
- int Len = EltVT.getSizeInBits();
+ assert(VT == MVT::v16i8 && "Only v16i8 vector CTPOP lowering supported.");
// This is the vectorized version of the "best" algorithm from
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
@@ -25108,36 +25090,27 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, const SDLoc &DL,
// x86, so set the SRL type to have elements at least i16 wide. This is
// correct because all of our SRLs are followed immediately by a mask anyways
// that handles any bits that sneak into the high bits of the byte elements.
- MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16);
-
+ MVT SrlVT = MVT::v8i16;
SDValue V = Op;
// v = v - ((v >> 1) & 0x55555555...)
SDValue Srl =
DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 1));
- SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55)));
+ SDValue And = GetMask(Srl, APInt(8, 0x55));
V = DAG.getNode(ISD::SUB, DL, VT, V, And);
// v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
- SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33)));
+ SDValue AndLHS = GetMask(V, APInt(8, 0x33));
Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 2));
- SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33)));
+ SDValue AndRHS = GetMask(Srl, APInt(8, 0x33));
V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
// v = (v + (v >> 4)) & 0x0F0F0F0F...
Srl = DAG.getBitcast(VT, GetShift(ISD::SRL, DAG.getBitcast(SrlVT, V), 4));
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
- V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F)));
+ V = GetMask(Add, APInt(8, 0x0F));
- // At this point, V contains the byte-wise population count, and we are
- // merely doing a horizontal sum if necessary to get the wider element
- // counts.
- if (EltVT == MVT::i8)
- return V;
-
- return LowerHorizontalByteSum(
- DAG.getBitcast(MVT::getVectorVT(MVT::i8, VecSize / 8), V), VT, Subtarget,
- DAG);
+ return V;
}
// Please ensure that any codegen change from LowerVectorCTPOP is reflected in
@@ -25163,12 +25136,6 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget,
}
}
- if (!Subtarget.hasSSSE3()) {
- // We can't use the fast LUT approach, so fall back on vectorized bitmath.
- assert(VT.is128BitVector() && "Only 128-bit vectors supported in SSE!");
- return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
- }
-
// Decompose 256-bit ops into smaller 128-bit ops.
if (VT.is256BitVector() && !Subtarget.hasInt256())
return Lower256IntUnary(Op, DAG);
@@ -25177,6 +25144,18 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget,
if (VT.is512BitVector() && !Subtarget.hasBWI())
return Lower512IntUnary(Op, DAG);
+ // For element types greater than i8, do vXi8 pop counts and a bytesum.
+ if (VT.getScalarType() != MVT::i8) {
+ MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
+ SDValue ByteOp = DAG.getBitcast(ByteVT, Op0);
+ SDValue PopCnt8 = DAG.getNode(ISD::CTPOP, DL, ByteVT, ByteOp);
+ return LowerHorizontalByteSum(PopCnt8, VT, Subtarget, DAG);
+ }
+
+ // We can't use the fast LUT approach, so fall back on vectorized bitmath.
+ if (!Subtarget.hasSSSE3())
+ return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
+
return LowerVectorCTPOPInRegLUT(Op0, DL, Subtarget, DAG);
}
OpenPOWER on IntegriCloud