diff options
Diffstat (limited to 'llvm/lib/CodeGen')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index b2de2a4f343..4b21b96c9df 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2300,6 +2300,52 @@ bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) { (AllowUndefs || !UndefElts); } +SDValue SelectionDAG::getSplatSourceVector(SDValue V, int &SplatIdx) { + V = peekThroughExtractSubvectors(V); + + EVT VT = V.getValueType(); + unsigned Opcode = V.getOpcode(); + switch (Opcode) { + default: { + APInt UndefElts; + APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements()); + if (isSplatValue(V, DemandedElts, UndefElts)) { + // Handle case where all demanded elements are UNDEF. + if (DemandedElts.isSubsetOf(UndefElts)) { + SplatIdx = 0; + return getUNDEF(VT); + } + SplatIdx = (UndefElts & DemandedElts).countTrailingOnes(); + return V; + } + break; + } + case ISD::VECTOR_SHUFFLE: { + // Check if this is a shuffle node doing a splat. + // TODO - remove this and rely purely on SelectionDAG::isSplatValue, + // getTargetVShiftNode currently struggles without the splat source. + auto *SVN = cast<ShuffleVectorSDNode>(V); + if (!SVN->isSplat()) + break; + int Idx = SVN->getSplatIndex(); + int NumElts = V.getValueType().getVectorNumElements(); + SplatIdx = Idx % NumElts; + return V.getOperand(Idx / NumElts); + } + } + + return SDValue(); +} + +SDValue SelectionDAG::getSplatValue(SDValue V) { + int SplatIdx; + if (SDValue SrcVector = getSplatSourceVector(V, SplatIdx)) + return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), + SrcVector.getValueType().getScalarType(), SrcVector, + getIntPtrConstant(SplatIdx, SDLoc(V))); + return SDValue(); +} + /// If a SHL/SRA/SRL node has a constant or splat constant shift amount that /// is less than the element bit-width of the shift node, return it. static const APInt *getValidShiftAmountConstant(SDValue V) { @@ -8585,6 +8631,12 @@ SDValue llvm::peekThroughOneUseBitcasts(SDValue V) { return V; } +SDValue llvm::peekThroughExtractSubvectors(SDValue V) { + while (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) + V = V.getOperand(0); + return V; +} + bool llvm::isBitwiseNot(SDValue V) { if (V.getOpcode() != ISD::XOR) return false; |

