From eb508f8ccb21caa7c9d0f6c2479d73ac79d6b25c Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Wed, 12 Dec 2018 18:32:29 +0000 Subject: [SelectionDAG] Add a generic isSplatValue function This patch introduces a generic function to determine whether a given vector type is known to be a splat value for the specified demanded elements, recursing up the DAG looking for BUILD_VECTOR or VECTOR_SHUFFLE splat patterns. It also keeps track of the elements that are known to be UNDEF - it returns true if all the demanded elements are UNDEF (as this may be useful under some circumstances), so this needs to be handled by the caller. A wrapper variant is also provided that doesn't take the DemandedElts or UndefElts arguments for cases where we just want to know if the SDValue is a splat or not (with/without UNDEFS). I had hoped to completely remove the X86 local version of this function, but I'm seeing some regressions in shift/rotate codegen that will take a little longer to fix and I hope to get this in sooner so I can continue work on PR38243 which needs more capable splat detection. Differential Revision: https://reviews.llvm.org/D55426 llvm-svn: 348953 --- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 96 ++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) (limited to 'llvm/lib/CodeGen/SelectionDAG') diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 01364944b22..62bb94ccba2 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2121,6 +2121,102 @@ bool SelectionDAG::MaskedValueIsZero(SDValue Op, const APInt &Mask, return Mask.isSubsetOf(computeKnownBits(Op, Depth).Zero); } +/// isSplatValue - Return true if the vector V has the same value +/// across all DemandedElts. +bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts, + APInt &UndefElts) { + if (!DemandedElts) + return false; // No demanded elts, better to assume we don't know anything. + + EVT VT = V.getValueType(); + assert(VT.isVector() && "Vector type expected"); + + unsigned NumElts = VT.getVectorNumElements(); + assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch"); + UndefElts = APInt::getNullValue(NumElts); + + switch (V.getOpcode()) { + case ISD::BUILD_VECTOR: { + SDValue Scl; + for (unsigned i = 0; i != NumElts; ++i) { + SDValue Op = V.getOperand(i); + if (Op.isUndef()) { + UndefElts.setBit(i); + continue; + } + if (!DemandedElts[i]) + continue; + if (Scl && Scl != Op) + return false; + Scl = Op; + } + return true; + } + case ISD::VECTOR_SHUFFLE: { + // Check if this is a shuffle node doing a splat. + // TODO: Do we need to handle shuffle(splat, undef, mask)? + int SplatIndex = -1; + ArrayRef Mask = cast(V)->getMask(); + for (int i = 0; i != (int)NumElts; ++i) { + int M = Mask[i]; + if (M < 0) { + UndefElts.setBit(i); + continue; + } + if (!DemandedElts[i]) + continue; + if (0 <= SplatIndex && SplatIndex != M) + return false; + SplatIndex = M; + } + return true; + } + case ISD::EXTRACT_SUBVECTOR: { + SDValue Src = V.getOperand(0); + ConstantSDNode *SubIdx = dyn_cast(V.getOperand(1)); + unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); + if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) { + // Offset the demanded elts by the subvector index. + uint64_t Idx = SubIdx->getZExtValue(); + APInt UndefSrcElts; + APInt DemandedSrc = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx); + if (isSplatValue(Src, DemandedSrc, UndefSrcElts)) { + UndefElts = UndefSrcElts.extractBits(NumElts, Idx); + return true; + } + } + break; + } + case ISD::ADD: + case ISD::SUB: + case ISD::AND: { + APInt UndefLHS, UndefRHS; + SDValue LHS = V.getOperand(0); + SDValue RHS = V.getOperand(1); + if (isSplatValue(LHS, DemandedElts, UndefLHS) && + isSplatValue(RHS, DemandedElts, UndefRHS)) { + UndefElts = UndefLHS | UndefRHS; + return true; + } + break; + } + } + + return false; +} + +/// Helper wrapper to main isSplatValue function. +bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) { + EVT VT = V.getValueType(); + assert(VT.isVector() && "Vector type expected"); + unsigned NumElts = VT.getVectorNumElements(); + + APInt UndefElts; + APInt DemandedElts = APInt::getAllOnesValue(NumElts); + return isSplatValue(V, DemandedElts, UndefElts) && + (AllowUndefs || !UndefElts); +} + /// Helper function that checks to see if a node is a constant or a /// build vector of splat constants at least within the demanded elts. static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N, -- cgit v1.2.3