diff options
author | Roman Lebedev <lebedev.ri@gmail.com> | 2019-07-20 16:33:15 +0000 |
---|---|---|
committer | Roman Lebedev <lebedev.ri@gmail.com> | 2019-07-20 16:33:15 +0000 |
commit | cd9b19484b63c373de00aaacc4ee4da81247c17c (patch) | |
tree | 1754f3321bda193afc6052bcefc99bb463d7d22e /llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | |
parent | adec0f22524ad869a5925923672a6d41251ad7c2 (diff) | |
download | bcm5719-llvm-cd9b19484b63c373de00aaacc4ee4da81247c17c.tar.gz bcm5719-llvm-cd9b19484b63c373de00aaacc4ee4da81247c17c.zip |
[Codegen][SelectionDAG] X u% C == 0 fold: non-splat vector improvements
Summary:
Four things here:
1. Generalize the fold to handle non-splat divisors. Reasonably trivial.
2. Unban power-of-two divisors. I don't see any reason why they should
be illegal.
* There is no ban in Hacker's Delight
* I think the ban came from the same bug that caused the miscompile
in the base patch - in `floor((2^W - 1) / D)` we were dividing by
`D0` instead of `D`, and we **were** ensuring that `D0` is not `1`,
which made sense.
3. Unban `1` divisors. I no longer believe Hacker's Delight actually says
that the fold is invalid for `D = 0`. Further considerations:
* We know that
* `(X u% 1) == 0` can be constant-folded to `1`,
* `(X u% 1) != 0` can be constant-folded to `0`,
* Also, we know that
* `X u<= -1` can be constant-folded to `1`,
* `X u> -1` can be constant-folded to `0`,
* https://godbolt.org/z/7jnZJX https://rise4fun.com/Alive/oF6p
* We know will end up with the following:
`(setule/setugt (rotr (mul N, P), K), Q)`
* Therefore, for given new DAG nodes and comparison predicates
(`ule`/`ugt`), we will still produce the correct answer if:
`Q` is a all-ones constant; and both `P` and `K` are *anything*
other than `undef`.
* The fold will indeed produce `Q = all-ones`.
4. Try to re-splat the `P` and `K` vectors - we don't care about
their values for the lanes where divisor was `1`.
Reviewers: RKSimon, hermord, craig.topper, spatel, xbolva00
Reviewed By: RKSimon
Subscribers: hiraditya, javed.absar, dexonsmith, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D63963
llvm-svn: 366637
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 167 |
1 files changed, 132 insertions, 35 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index b260cd91d46..bab99b2e029 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -4455,6 +4455,34 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, return DAG.getSelect(dl, VT, IsOne, N0, Q); } +/// If all values in Values that *don't* match the predicate are same 'splat' +/// value, then replace all values with that splat value. +/// Else, if AlternativeReplacement was provided, then replace all values that +/// do match predicate with AlternativeReplacement value. +static void +turnVectorIntoSplatVector(MutableArrayRef<SDValue> Values, + std::function<bool(SDValue)> Predicate, + SDValue AlternativeReplacement = SDValue()) { + SDValue Replacement; + // Is there a value for which the Predicate does *NOT* match? What is it? + auto SplatValue = llvm::find_if_not(Values, Predicate); + if (SplatValue != Values.end()) { + // Does Values consist only of SplatValue's and values matching Predicate? + if (llvm::all_of(Values, [Predicate, SplatValue](SDValue Value) { + return Value == *SplatValue || Predicate(Value); + })) // Then we shall replace values matching predicate with SplatValue. + Replacement = *SplatValue; + } + if (!Replacement) { + // Oops, we did not find the "baseline" splat value. + if (!AlternativeReplacement) + return; // Nothing to do. + // Let's replace with provided value then. + Replacement = AlternativeReplacement; + } + std::replace_if(Values.begin(), Values.end(), Predicate, Replacement); +} + /// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE /// where the divisor is constant and the comparison target is zero, /// return a DAG expression that will generate the same comparison result @@ -4482,74 +4510,143 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, DAGCombinerInfo &DCI, const SDLoc &DL, SmallVectorImpl<SDNode *> &Created) const { // fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q) - // - D must be constant with D = D0 * 2^K where D0 is odd and D0 != 1 + // - D must be constant, with D = D0 * 2^K where D0 is odd // - P is the multiplicative inverse of D0 modulo 2^W // - Q = floor((2^W - 1) / D0) // where W is the width of the common type of N and D. assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Only applicable for (in)equality comparisons."); + SelectionDAG &DAG = DCI.DAG; + EVT VT = REMNode.getValueType(); + EVT SVT = VT.getScalarType(); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + EVT ShSVT = ShVT.getScalarType(); // If MUL is unavailable, we cannot proceed in any case. if (!isOperationLegalOrCustom(ISD::MUL, VT)) return SDValue(); - // TODO: Add non-uniform constant support. - ConstantSDNode *Divisor = isConstOrConstSplat(REMNode->getOperand(1)); + // TODO: Could support comparing with non-zero too. ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode); - if (!Divisor || !CompTarget || Divisor->isNullValue() || - !CompTarget->isNullValue()) + if (!CompTarget || !CompTarget->isNullValue()) return SDValue(); - const APInt &D = Divisor->getAPIntValue(); + bool HadOneDivisor = false; + bool AllDivisorsAreOnes = true; + bool HadEvenDivisor = false; + bool AllDivisorsArePowerOfTwo = true; + SmallVector<SDValue, 16> PAmts, KAmts, QAmts; + + auto BuildUREMPattern = [&](ConstantSDNode *C) { + // Division by 0 is UB. Leave it to be constant-folded elsewhere. + if (C->isNullValue()) + return false; - // Decompose D into D0 * 2^K - unsigned K = D.countTrailingZeros(); - bool DivisorIsEven = (K != 0); - APInt D0 = D.lshr(K); + const APInt &D = C->getAPIntValue(); + // If all divisors are ones, we will prefer to avoid the fold. + HadOneDivisor |= D.isOneValue(); + AllDivisorsAreOnes &= D.isOneValue(); + + // Decompose D into D0 * 2^K + unsigned K = D.countTrailingZeros(); + assert((!D.isOneValue() || (K == 0)) && "For divisor '1' we won't rotate."); + APInt D0 = D.lshr(K); + + // D is even if it has trailing zeros. + HadEvenDivisor |= (K != 0); + // D is a power-of-two if D0 is one. + // If all divisors are power-of-two, we will prefer to avoid the fold. + AllDivisorsArePowerOfTwo &= D0.isOneValue(); + + // P = inv(D0, 2^W) + // 2^W requires W + 1 bits, so we have to extend and then truncate. + unsigned W = D.getBitWidth(); + APInt P = D0.zext(W + 1) + .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) + .trunc(W); + assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable + assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check."); + + // Q = floor((2^W - 1) / D) + APInt Q = APInt::getAllOnesValue(W).udiv(D); + + assert(APInt::getAllOnesValue(ShSVT.getSizeInBits()).ugt(K) && + "We are expecting that K is always less than all-ones for ShSVT"); + + // If the divisor is 1 the result can be constant-folded. + if (D.isOneValue()) { + // Set P and K amount to a bogus values so we can try to splat them. + P = 0; + K = -1; + assert(Q.isAllOnesValue() && + "Expecting all-ones comparison for one divisor"); + } - // The fold is invalid when D0 == 1. - // This is reachable because visitSetCC happens before visitREM. - if (D0.isOneValue()) + PAmts.push_back(DAG.getConstant(P, DL, SVT)); + KAmts.push_back( + DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT)); + QAmts.push_back(DAG.getConstant(Q, DL, SVT)); + return true; + }; + + SDValue N = REMNode.getOperand(0); + SDValue D = REMNode.getOperand(1); + + // Collect the values from each element. + if (!ISD::matchUnaryPredicate(D, BuildUREMPattern)) return SDValue(); - // P = inv(D0, 2^W) - // 2^W requires W + 1 bits, so we have to extend and then truncate. - unsigned W = D.getBitWidth(); - APInt P = D0.zext(W + 1) - .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) - .trunc(W); - assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable - assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check."); + // If this is a urem by a one, avoid the fold since it can be constant-folded. + if (AllDivisorsAreOnes) + return SDValue(); - // Q = floor((2^W - 1) / D) - APInt Q = APInt::getAllOnesValue(W).udiv(D); + // If this is a urem by a powers-of-two, avoid the fold since it can be + // best implemented as a bit test. + if (AllDivisorsArePowerOfTwo) + return SDValue(); - SelectionDAG &DAG = DCI.DAG; + SDValue PVal, KVal, QVal; + if (VT.isVector()) { + if (HadOneDivisor) { + // Try to turn PAmts into a splat, since we don't care about the values + // that are currently '0'. If we can't, just keep '0'`s. + turnVectorIntoSplatVector(PAmts, isNullConstant); + // Try to turn KAmts into a splat, since we don't care about the values + // that are currently '-1'. If we can't, change them to '0'`s. + turnVectorIntoSplatVector(KAmts, isAllOnesConstant, + DAG.getConstant(0, DL, ShSVT)); + } + + PVal = DAG.getBuildVector(VT, DL, PAmts); + KVal = DAG.getBuildVector(ShVT, DL, KAmts); + QVal = DAG.getBuildVector(VT, DL, QAmts); + } else { + PVal = PAmts[0]; + KVal = KAmts[0]; + QVal = QAmts[0]; + } - SDValue PVal = DAG.getConstant(P, DL, VT); - SDValue QVal = DAG.getConstant(Q, DL, VT); // (mul N, P) - SDValue Op1 = DAG.getNode(ISD::MUL, DL, VT, REMNode->getOperand(0), PVal); - Created.push_back(Op1.getNode()); + SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal); + Created.push_back(Op0.getNode()); - // Rotate right only if D was even. - if (DivisorIsEven) { + // Rotate right only if any divisor was even. We avoid rotates for all-odd + // divisors as a performance improvement, since rotating by 0 is a no-op. + if (HadEvenDivisor) { // We need ROTR to do this. if (!isOperationLegalOrCustom(ISD::ROTR, VT)) return SDValue(); - SDValue ShAmt = - DAG.getConstant(K, DL, getShiftAmountTy(VT, DAG.getDataLayout())); SDNodeFlags Flags; Flags.setExact(true); // UREM: (rotr (mul N, P), K) - Op1 = DAG.getNode(ISD::ROTR, DL, VT, Op1, ShAmt, Flags); - Created.push_back(Op1.getNode()); + Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal, Flags); + Created.push_back(Op0.getNode()); } // UREM: (setule/setugt (rotr (mul N, P), K), Q) - return DAG.getSetCC(DL, SETCCVT, Op1, QVal, + return DAG.getSetCC(DL, SETCCVT, Op0, QVal, ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT)); } |