diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2019-02-25 16:31:58 +0000 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2019-02-25 16:31:58 +0000 |
commit | 80d0e9c563d9628da8d194cbb5096eabffb07bff (patch) | |
tree | a9c6d1f98fa7e773b67468f6719cd69887c31489 /llvm/lib/CodeGen | |
parent | 8a7f4c98913b4fd6253768a50116eedf99b79832 (diff) | |
download | bcm5719-llvm-80d0e9c563d9628da8d194cbb5096eabffb07bff.tar.gz bcm5719-llvm-80d0e9c563d9628da8d194cbb5096eabffb07bff.zip |
[SelectionDAG] Add demanded elts variants to isConstOrConstSplat helpers. NFCI.
These helpers extend the existing isConstOrConstSplat helper checks to support DemandedElts masks as well.
We already had a local version of this in SelectionDAG that computeKnownBits/ComputeNumSignBits made use of, but this adds the functionality directly to the BuildVectorSDNode node and extends isConstOrConstSplat etc. to use that.
This will allow us to reuse the functionality in SimplifyDemandedVectorElts/SimplifyDemandedBits.
Differential Revision: https://reviews.llvm.org/D58503
llvm-svn: 354797
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 111 |
1 files changed, 74 insertions, 37 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index b9be29be61d..c0bcb7ab1ff 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2253,30 +2253,6 @@ bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) { (AllowUndefs || !UndefElts); } -/// Helper function that checks to see if a node is a constant or a -/// build vector of splat constants at least within the demanded elts. -static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N, - const APInt &DemandedElts) { - if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) - return CN; - if (N.getOpcode() != ISD::BUILD_VECTOR) - return nullptr; - EVT VT = N.getValueType(); - ConstantSDNode *Cst = nullptr; - unsigned NumElts = VT.getVectorNumElements(); - assert(DemandedElts.getBitWidth() == NumElts && "Unexpected vector size"); - for (unsigned i = 0; i != NumElts; ++i) { - if (!DemandedElts[i]) - continue; - ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(i)); - if (!C || (Cst && Cst->getAPIntValue() != C->getAPIntValue()) || - C->getValueType(0) != VT.getScalarType()) - return nullptr; - Cst = C; - } - return Cst; -} - /// If a SHL/SRA/SRL node has a constant or splat constant shift amount that /// is less than the element bit-width of the shift node, return it. static const APInt *getValidShiftAmountConstant(SDValue V) { @@ -2717,8 +2693,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, break; case ISD::FSHL: case ISD::FSHR: - if (ConstantSDNode *C = - isConstOrDemandedConstSplat(Op.getOperand(2), DemandedElts)) { + if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) { unsigned Amt = C->getAPIntValue().urem(BitWidth); // For fshl, 0-shift returns the 1st arg. @@ -3155,10 +3130,10 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, // the minimum of the clamp min/max range. bool IsMax = (Opcode == ISD::SMAX); ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr; - if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts))) + if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts))) if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX)) - CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1), - DemandedElts); + CstHigh = + isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts); if (CstLow && CstHigh) { if (!IsMax) std::swap(CstLow, CstHigh); @@ -3439,7 +3414,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1); // SRA X, C -> adds C sign bits. if (ConstantSDNode *C = - isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) { + isConstOrConstSplat(Op.getOperand(1), DemandedElts)) { APInt ShiftVal = C->getAPIntValue(); ShiftVal += Tmp; Tmp = ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue(); @@ -3447,7 +3422,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, return Tmp; case ISD::SHL: if (ConstantSDNode *C = - isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) { + isConstOrConstSplat(Op.getOperand(1), DemandedElts)) { // shl destroys sign bits. Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1); if (C->getAPIntValue().uge(VTBits) || // Bad shift. @@ -3487,10 +3462,10 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, // the minimum of the clamp min/max range. bool IsMax = (Opcode == ISD::SMAX); ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr; - if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts))) + if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts))) if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX)) - CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1), - DemandedElts); + CstHigh = + isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts); if (CstLow && CstHigh) { if (!IsMax) std::swap(CstLow, CstHigh); @@ -8593,6 +8568,24 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) { return nullptr; } +ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts, + bool AllowUndefs) { + if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) { + BitVector UndefElements; + ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements); + + // BuildVectors can truncate their operands. Ignore that case here. + if (CN && (UndefElements.none() || AllowUndefs) && + CN->getValueType(0) == N.getValueType().getScalarType()) + return CN; + } + + return nullptr; +} + ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) { if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N)) return CN; @@ -8607,6 +8600,23 @@ ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) { return nullptr; } +ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, + const APInt &DemandedElts, + bool AllowUndefs) { + if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) { + BitVector UndefElements; + ConstantFPSDNode *CN = + BV->getConstantFPSplatNode(DemandedElts, &UndefElements); + if (CN && (UndefElements.none() || AllowUndefs)) + return CN; + } + + return nullptr; +} + bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) { // TODO: may want to use peekThroughBitcast() here. ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs); @@ -9193,13 +9203,20 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue, APInt &SplatUndef, return true; } -SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const { +SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts, + BitVector *UndefElements) const { if (UndefElements) { UndefElements->clear(); UndefElements->resize(getNumOperands()); } + assert(getNumOperands() == DemandedElts.getBitWidth() && + "Unexpected vector size"); + if (!DemandedElts) + return SDValue(); SDValue Splatted; for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + if (!DemandedElts[i]) + continue; SDValue Op = getOperand(i); if (Op.isUndef()) { if (UndefElements) @@ -9212,20 +9229,40 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const { } if (!Splatted) { - assert(getOperand(0).isUndef() && + unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros(); + assert(getOperand(FirstDemandedIdx).isUndef() && "Can only have a splat without a constant for all undefs."); - return getOperand(0); + return getOperand(FirstDemandedIdx); } return Splatted; } +SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const { + APInt DemandedElts = APInt::getAllOnesValue(getNumOperands()); + return getSplatValue(DemandedElts, UndefElements); +} + +ConstantSDNode * +BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts, + BitVector *UndefElements) const { + return dyn_cast_or_null<ConstantSDNode>( + getSplatValue(DemandedElts, UndefElements)); +} + ConstantSDNode * BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const { return dyn_cast_or_null<ConstantSDNode>(getSplatValue(UndefElements)); } ConstantFPSDNode * +BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts, + BitVector *UndefElements) const { + return dyn_cast_or_null<ConstantFPSDNode>( + getSplatValue(DemandedElts, UndefElements)); +} + +ConstantFPSDNode * BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const { return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements)); } |