diff options
author | Michael Kuperstein <mkuper@google.com> | 2016-07-29 21:45:51 +0000 |
---|---|---|
committer | Michael Kuperstein <mkuper@google.com> | 2016-07-29 21:45:51 +0000 |
commit | f396b4c40dad7df67f7cadc6aaa7ac3dbf42f866 (patch) | |
tree | 41d0f8a47ae92af93d0f516baf1b311c630678dd /llvm/lib | |
parent | e9ac9b9aaf51c49e54f9f3664cb6b5e28cae0c39 (diff) | |
download | bcm5719-llvm-f396b4c40dad7df67f7cadc6aaa7ac3dbf42f866.tar.gz bcm5719-llvm-f396b4c40dad7df67f7cadc6aaa7ac3dbf42f866.zip |
[X86] Match PSADBW in straight-line code
Up until now, we only had code to match PSADBW patterns that look like what
comes out of the loop vectorizer - a partial reduction inside the loop body
that gets fed into a horizontal operation in a different basic block.
This adds support for straight-line patterns, like those generated by the
SLP vectorizer.
Differential Revision: https://reviews.llvm.org/D22889
llvm-svn: 277219
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 137 |
1 files changed, 135 insertions, 2 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 01389e77562..d28a661e9ad 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -26278,6 +26278,58 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, return SDValue(); } + +// Match a binop + shuffle pyramid that represents a horizontal reduction over +// the elements of a vector. +// Returns the vector that is being reduced on, or SDValue() if a reduction +// was not matched. +static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { + // The pattern must end in an extract from index 0. + if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) || + !isNullConstant(Extract->getOperand(1))) + return SDValue(); + + unsigned Stages = + Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements()); + + SDValue Op = Extract->getOperand(0); + // At each stage, we're looking for something that looks like: + // %s = shufflevector <8 x i32> %op, <8 x i32> undef, + // <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, + // i32 undef, i32 undef, i32 undef, i32 undef> + // %a = binop <8 x i32> %op, %s + // Where the mask changes according to the stage. E.g. for a 3-stage pyramid, + // we expect something like: + // <4,5,6,7,u,u,u,u> + // <2,3,u,u,u,u,u,u> + // <1,u,u,u,u,u,u,u> + for (unsigned i = 0; i < Stages; ++i) { + if (Op.getOpcode() != BinOp) + return SDValue(); + + ShuffleVectorSDNode *Shuffle = + dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0).getNode()); + if (Shuffle) { + Op = Op.getOperand(1); + } else { + Shuffle = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1).getNode()); + Op = Op.getOperand(0); + } + + // The first operand of the shuffle should be the same as the other operand + // of the add. + if (!Shuffle || (Shuffle->getOperand(0) != Op)) + return SDValue(); + + // Verify the shuffle has the expected (at this stage of the pyramid) mask. + for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index) + if (Shuffle->getMaskElt(Index) != MaskEnd + Index) + return SDValue(); + } + + return Op; +} + // Given a select, detect the following pattern: // 1: %2 = zext <N x i8> %0 to <N x i32> // 2: %3 = zext <N x i8> %1 to <N x i32> @@ -26358,12 +26410,81 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1); } +static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // PSADBW is only supported on SSE2 and up. + if (!Subtarget.hasSSE2()) + return SDValue(); + + // Verify the type we're extracting from is appropriate + // TODO: There's nothing special about i32, any integer type above i16 should + // work just as well. + EVT VT = Extract->getOperand(0).getValueType(); + if (!VT.isSimple() || !(VT.getVectorElementType() == MVT::i32)) + return SDValue(); + + unsigned RegSize = 128; + if (Subtarget.hasBWI()) + RegSize = 512; + else if (Subtarget.hasAVX2()) + RegSize = 256; + + // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. + // TODO: We should be able to handle larger vectors by splitting them before + // feeding them into several SADs, and then reducing over those. + if (VT.getSizeInBits() / 4 > RegSize) + return SDValue(); + + // Match shuffle + add pyramid. + SDValue Root = matchBinOpReduction(Extract, ISD::ADD); + + // If there was a match, we want Root to be a select that is the root of an + // abs-diff pattern. + if (!Root || (Root.getOpcode() != ISD::VSELECT)) + return SDValue(); + + // Check whether we have an abs-diff pattern feeding into the select. + SDValue Zext0, Zext1; + if (!detectZextAbsDiff(Root, Zext0, Zext1)) + return SDValue(); + + // Create the SAD instruction + SDLoc DL(Extract); + SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL); + + // If the original vector was wider than 8 elements, sum over the results + // in the SAD vector. + unsigned Stages = Log2_32(VT.getVectorNumElements()); + MVT SadVT = SAD.getSimpleValueType(); + if (Stages > 3) { + unsigned SadElems = SadVT.getVectorNumElements(); + + for(unsigned i = Stages - 3; i > 0; --i) { + SmallVector<int, 16> Mask(SadElems, -1); + for(unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j) + Mask[j] = MaskEnd + j; + + SDValue Shuffle = + DAG.getVectorShuffle(SadVT, DL, SAD, DAG.getUNDEF(SadVT), Mask); + SAD = DAG.getNode(ISD::ADD, DL, SadVT, SAD, Shuffle); + } + } + + + // Return the lowest i32. + MVT ResVT = MVT::getVectorVT(MVT::i32, SadVT.getSizeInBits() / 32); + SAD = DAG.getNode(ISD::BITCAST, DL, ResVT, SAD); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, SAD, + Extract->getOperand(1)); +} + /// Detect vector gather/scatter index generation and convert it from being a /// bunch of shuffles and extracts into a somewhat faster sequence. /// For i686, the best sequence is apparently storing the value and loading /// scalars back, while for x64 we should use 64-bit extracts and shifts. static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { if (SDValue NewOp = XFormVExtractWithShuffleIntoLoad(N, DAG, DCI)) return NewOp; @@ -26394,6 +26515,13 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, uint64_t Res = (InputValue >> ExtractedElt) & 1; return DAG.getConstant(Res, dl, MVT::i1); } + + // Check whether this extract is the root of a sum of absolute differences + // pattern. This has to be done here because we really want it to happen + // pre-legalization, + if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) + return SAD; + // Only operate on vectors of 4 elements, where the alternative shuffling // gets to be more expensive. if (InputVector.getValueType() != MVT::v4i32) @@ -30730,6 +30858,8 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); + // TODO: There's nothing special about i32, any integer type above i16 should + // work just as well. if (!VT.isVector() || !VT.isSimple() || !(VT.getVectorElementType() == MVT::i32)) return SDValue(); @@ -30741,6 +30871,8 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, RegSize = 256; // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. + // TODO: We should be able to handle larger vectors by splitting them before + // feeding them into several SADs, and then reducing over those. if (VT.getSizeInBits() / 4 > RegSize) return SDValue(); @@ -30978,7 +31110,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, SelectionDAG &DAG = DCI.DAG; switch (N->getOpcode()) { default: break; - case ISD::EXTRACT_VECTOR_ELT: return combineExtractVectorElt(N, DAG, DCI); + case ISD::EXTRACT_VECTOR_ELT: + return combineExtractVectorElt(N, DAG, DCI, Subtarget); case ISD::VSELECT: case ISD::SELECT: case X86ISD::SHRUNKBLEND: return combineSelect(N, DAG, DCI, Subtarget); |