diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2018-12-17 18:43:43 +0000 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2018-12-17 18:43:43 +0000 |
commit | 9274f17a5ec5869e944d77b9f9c81c5f8063f360 (patch) | |
tree | 687a7d88d1942bb53dffd91c2586706414f991dd /llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | |
parent | 077a0aff164a651ac439a035571dfe6af85f0221 (diff) | |
download | bcm5719-llvm-9274f17a5ec5869e944d77b9f9c81c5f8063f360.tar.gz bcm5719-llvm-9274f17a5ec5869e944d77b9f9c81c5f8063f360.zip |
[TargetLowering] Add DemandedElts mask to SimplifyDemandedBits (PR40000)
This is an initial patch that adds the necessary support for a DemandedElts argument to SimplifyDemandedBits, both to more closely match computeKnownBits and to help improve vector codegen.
I've added only a small number of the changes necessary to get at least one test to update. A lot more can be done, but I'd like to add these methodically with proper test coverage. At the same time, the hope is to slowly move some/all of SimplifyDemandedVectorElts into SimplifyDemandedBits as well.
Differential Revision: https://reviews.llvm.org/D55768
llvm-svn: 349374
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 162 |
1 files changed, 120 insertions, 42 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 310eee2fb03..0c9ef5b4da9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -496,23 +496,41 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, return Simplified; } +bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, + KnownBits &Known, + TargetLoweringOpt &TLO, + unsigned Depth, + bool AssumeSingleUse) const { + EVT VT = Op.getValueType(); + APInt DemandedElts = VT.isVector() + ? APInt::getAllOnesValue(VT.getVectorNumElements()) + : APInt(1, 1); + return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth, + AssumeSingleUse); +} + /// Look at Op. At this point, we know that only the OriginalDemandedBits of the /// result of Op are ever used downstream. If we can use this information to /// simplify Op, create a new simplified DAG node and return true, returning the /// original and new nodes in Old and New. Otherwise, analyze the expression and /// return a mask of Known bits for the expression (used to simplify the /// caller). The Known bits may only be accurate for those bits in the -/// DemandedMask. -bool TargetLowering::SimplifyDemandedBits(SDValue Op, - const APInt &OriginalDemandedBits, - KnownBits &Known, - TargetLoweringOpt &TLO, - unsigned Depth, - bool AssumeSingleUse) const { +/// OriginalDemandedBits and OriginalDemandedElts. 
+bool TargetLowering::SimplifyDemandedBits( + SDValue Op, const APInt &OriginalDemandedBits, + const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, + unsigned Depth, bool AssumeSingleUse) const { unsigned BitWidth = OriginalDemandedBits.getBitWidth(); assert(Op.getScalarValueSizeInBits() == BitWidth && "Mask size mismatches value type size!"); + + unsigned NumElts = OriginalDemandedElts.getBitWidth(); + assert((!Op.getValueType().isVector() || + NumElts == Op.getValueType().getVectorNumElements()) && + "Unexpected vector size"); + APInt DemandedBits = OriginalDemandedBits; + APInt DemandedElts = OriginalDemandedElts; SDLoc dl(Op); auto &DL = TLO.DAG.getDataLayout(); @@ -532,18 +550,19 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, if (Depth != 0) { // If not at the root, Just compute the Known bits to // simplify things downstream. - TLO.DAG.computeKnownBits(Op, Known, Depth); + TLO.DAG.computeKnownBits(Op, Known, DemandedElts, Depth); return false; } // If this is the root being simplified, allow it to have multiple uses, - // just set the DemandedBits to all bits. + // just set the DemandedBits/Elts to all bits. DemandedBits = APInt::getAllOnesValue(BitWidth); - } else if (OriginalDemandedBits == 0) { - // Not demanding any bits from Op. + DemandedElts = APInt::getAllOnesValue(NumElts); + } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) { + // Not demanding any bits/elts from Op. if (!Op.isUndef()) return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); return false; - } else if (Depth == 6) { // Limit search depth. + } else if (Depth == 6) { // Limit search depth. return false; } @@ -573,18 +592,71 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, Known.One &= Known2.One; Known.Zero &= Known2.Zero; } - return false; // Don't fall through, will infinitely loop. - case ISD::CONCAT_VECTORS: + return false; // Don't fall through, will infinitely loop. 
+ case ISD::CONCAT_VECTORS: { Known.Zero.setAllBits(); Known.One.setAllBits(); - for (SDValue SrcOp : Op->ops()) { - if (SimplifyDemandedBits(SrcOp, DemandedBits, Known2, TLO, Depth + 1)) + EVT SubVT = Op.getOperand(0).getValueType(); + unsigned NumSubVecs = Op.getNumOperands(); + unsigned NumSubElts = SubVT.getVectorNumElements(); + for (unsigned i = 0; i != NumSubVecs; ++i) { + APInt DemandedSubElts = + DemandedElts.extractBits(NumSubElts, i * NumSubElts); + if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts, + Known2, TLO, Depth + 1)) return true; - // Known bits are the values that are shared by every subvector. - Known.One &= Known2.One; - Known.Zero &= Known2.Zero; + // Known bits are shared by every demanded subvector element. + if (!!DemandedSubElts) { + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + } + break; + } + case ISD::VECTOR_SHUFFLE: { + ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask(); + + // Collect demanded elements from shuffle operands.. + APInt DemandedLHS(NumElts, 0); + APInt DemandedRHS(NumElts, 0); + for (unsigned i = 0; i != NumElts; ++i) { + if (!DemandedElts[i]) + continue; + int M = ShuffleMask[i]; + if (M < 0) { + // For UNDEF elements, we don't know anything about the common state of + // the shuffle result. 
+ DemandedLHS.clearAllBits(); + DemandedRHS.clearAllBits(); + break; + } + assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range"); + if (M < (int)NumElts) + DemandedLHS.setBit(M); + else + DemandedRHS.setBit(M - NumElts); + } + + if (!!DemandedLHS || !!DemandedRHS) { + Known.Zero.setAllBits(); + Known.One.setAllBits(); + if (!!DemandedLHS) { + if (SimplifyDemandedBits(Op.getOperand(0), DemandedBits, DemandedLHS, + Known2, TLO, Depth + 1)) + return true; + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + if (!!DemandedRHS) { + if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedRHS, + Known2, TLO, Depth + 1)) + return true; + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } } break; + } case ISD::AND: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); @@ -596,7 +668,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1)) { KnownBits LHSKnown; // Do not increment Depth here; that can cause an infinite loop. - TLO.DAG.computeKnownBits(Op0, LHSKnown, Depth); + TLO.DAG.computeKnownBits(Op0, LHSKnown, DemandedElts, Depth); // If the LHS already has zeros where RHSC does, this 'and' is dead. 
if ((LHSKnown.Zero & DemandedBits) == (~RHSC->getAPIntValue() & DemandedBits)) @@ -619,10 +691,10 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, } } - if (SimplifyDemandedBits(Op1, DemandedBits, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, Known2, TLO, + if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts, Known2, TLO, Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); @@ -653,10 +725,11 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); - if (SimplifyDemandedBits(Op1, DemandedBits, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, Known2, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts, Known2, TLO, + Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); @@ -683,10 +756,10 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); - if (SimplifyDemandedBits(Op1, DemandedBits, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op0, DemandedBits, Known2, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO, Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); @@ -840,7 +913,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, } } - if (SimplifyDemandedBits(Op0, 
DemandedBits.lshr(ShAmt), Known, TLO, + if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), DemandedElts, Known, TLO, Depth + 1)) return true; @@ -935,7 +1008,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, } // Compute the new bits that are at the top now. - if (SimplifyDemandedBits(Op0, InDemandedMask, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); @@ -974,7 +1047,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, if (DemandedBits.countLeadingZeros() < ShAmt) InDemandedMask.setSignBit(); - if (SimplifyDemandedBits(Op0, InDemandedMask, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); @@ -1221,18 +1294,26 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, break; } case ISD::EXTRACT_VECTOR_ELT: { - // Demand the bits from every vector element. SDValue Src = Op.getOperand(0); + SDValue Idx = Op.getOperand(1); + unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); unsigned EltBitWidth = Src.getScalarValueSizeInBits(); + // Demand the bits from every vector element without a constant index. + APInt DemandedSrcElts = APInt::getAllOnesValue(NumSrcElts); + if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) + if (CIdx->getAPIntValue().ult(NumSrcElts)) + DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue()); + // If BitWidth > EltBitWidth the value is anyext:ed. So we do not know // anything about the extended bits. 
APInt DemandedSrcBits = DemandedBits; if (BitWidth > EltBitWidth) DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth); - if (SimplifyDemandedBits(Src, DemandedSrcBits, Known2, TLO, Depth + 1)) - return true; + if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO, + Depth + 1)) + return true; Known = Known2; if (BitWidth > EltBitWidth) @@ -1313,8 +1394,8 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1); unsigned DemandedBitsLZ = DemandedBits.countLeadingZeros(); APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ); - if (SimplifyDemandedBits(Op0, LoMask, Known2, TLO, Depth + 1) || - SimplifyDemandedBits(Op1, LoMask, Known2, TLO, Depth + 1) || + if (SimplifyDemandedBits(Op0, LoMask, DemandedElts, Known2, TLO, Depth + 1) || + SimplifyDemandedBits(Op1, LoMask, DemandedElts, Known2, TLO, Depth + 1) || // See if the operation should be performed at a smaller bit width. ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) { SDNodeFlags Flags = Op.getNode()->getFlags(); @@ -1354,14 +1435,14 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, } default: if (Op.getOpcode() >= ISD::BUILTIN_OP_END) { - if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, Known, TLO, - Depth)) + if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts, + Known, TLO, Depth)) return true; break; } // Just use computeKnownBits to compute output bits. 
- TLO.DAG.computeKnownBits(Op, Known, Depth); + TLO.DAG.computeKnownBits(Op, Known, DemandedElts, Depth); break; } @@ -1887,8 +1968,8 @@ bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode( } bool TargetLowering::SimplifyDemandedBitsForTargetNode( - SDValue Op, const APInt &DemandedBits, KnownBits &Known, - TargetLoweringOpt &TLO, unsigned Depth) const { + SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, + KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const { assert((Op.getOpcode() >= ISD::BUILTIN_OP_END || Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || Op.getOpcode() == ISD::INTRINSIC_W_CHAIN || @@ -1896,9 +1977,6 @@ bool TargetLowering::SimplifyDemandedBitsForTargetNode( "Should use SimplifyDemandedBits if you don't know whether Op" " is a target node!"); EVT VT = Op.getValueType(); - APInt DemandedElts = VT.isVector() - ? APInt::getAllOnesValue(VT.getVectorNumElements()) - : APInt(1, 1); computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth); return false; } |