diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 167 |
1 files changed, 132 insertions, 35 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index b260cd91d46..bab99b2e029 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -4455,6 +4455,34 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, return DAG.getSelect(dl, VT, IsOne, N0, Q); } +/// If all values in Values that *don't* match the predicate are same 'splat' +/// value, then replace all values with that splat value. +/// Else, if AlternativeReplacement was provided, then replace all values that +/// do match predicate with AlternativeReplacement value. +static void +turnVectorIntoSplatVector(MutableArrayRef<SDValue> Values, + std::function<bool(SDValue)> Predicate, + SDValue AlternativeReplacement = SDValue()) { + SDValue Replacement; + // Is there a value for which the Predicate does *NOT* match? What is it? + auto SplatValue = llvm::find_if_not(Values, Predicate); + if (SplatValue != Values.end()) { + // Does Values consist only of SplatValue's and values matching Predicate? + if (llvm::all_of(Values, [Predicate, SplatValue](SDValue Value) { + return Value == *SplatValue || Predicate(Value); + })) // Then we shall replace values matching predicate with SplatValue. + Replacement = *SplatValue; + } + if (!Replacement) { + // Oops, we did not find the "baseline" splat value. + if (!AlternativeReplacement) + return; // Nothing to do. + // Let's replace with provided value then. + Replacement = AlternativeReplacement; + } + std::replace_if(Values.begin(), Values.end(), Predicate, Replacement); +} + /// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE /// where the divisor is constant and the comparison target is zero, /// return a DAG expression that will generate the same comparison result @@ -4482,74 +4510,143 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, DAGCombinerInfo &DCI, const SDLoc &DL, SmallVectorImpl<SDNode *> &Created) const { // fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q) - // - D must be constant with D = D0 * 2^K where D0 is odd and D0 != 1 + // - D must be constant, with D = D0 * 2^K where D0 is odd // - P is the multiplicative inverse of D0 modulo 2^W // - Q = floor((2^W - 1) / D0) // where W is the width of the common type of N and D. assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Only applicable for (in)equality comparisons."); + SelectionDAG &DAG = DCI.DAG; + EVT VT = REMNode.getValueType(); + EVT SVT = VT.getScalarType(); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + EVT ShSVT = ShVT.getScalarType(); // If MUL is unavailable, we cannot proceed in any case. if (!isOperationLegalOrCustom(ISD::MUL, VT)) return SDValue(); - // TODO: Add non-uniform constant support. - ConstantSDNode *Divisor = isConstOrConstSplat(REMNode->getOperand(1)); + // TODO: Could support comparing with non-zero too. ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode); - if (!Divisor || !CompTarget || Divisor->isNullValue() || - !CompTarget->isNullValue()) + if (!CompTarget || !CompTarget->isNullValue()) return SDValue(); - const APInt &D = Divisor->getAPIntValue(); + bool HadOneDivisor = false; + bool AllDivisorsAreOnes = true; + bool HadEvenDivisor = false; + bool AllDivisorsArePowerOfTwo = true; + SmallVector<SDValue, 16> PAmts, KAmts, QAmts; + + auto BuildUREMPattern = [&](ConstantSDNode *C) { + // Division by 0 is UB. Leave it to be constant-folded elsewhere. + if (C->isNullValue()) + return false; - // Decompose D into D0 * 2^K - unsigned K = D.countTrailingZeros(); - bool DivisorIsEven = (K != 0); - APInt D0 = D.lshr(K); + const APInt &D = C->getAPIntValue(); + // If all divisors are ones, we will prefer to avoid the fold. + HadOneDivisor |= D.isOneValue(); + AllDivisorsAreOnes &= D.isOneValue(); + + // Decompose D into D0 * 2^K + unsigned K = D.countTrailingZeros(); + assert((!D.isOneValue() || (K == 0)) && "For divisor '1' we won't rotate."); + APInt D0 = D.lshr(K); + + // D is even if it has trailing zeros. + HadEvenDivisor |= (K != 0); + // D is a power-of-two if D0 is one. + // If all divisors are power-of-two, we will prefer to avoid the fold. + AllDivisorsArePowerOfTwo &= D0.isOneValue(); + + // P = inv(D0, 2^W) + // 2^W requires W + 1 bits, so we have to extend and then truncate. + unsigned W = D.getBitWidth(); + APInt P = D0.zext(W + 1) + .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) + .trunc(W); + assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable + assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check."); + + // Q = floor((2^W - 1) / D) + APInt Q = APInt::getAllOnesValue(W).udiv(D); + + assert(APInt::getAllOnesValue(ShSVT.getSizeInBits()).ugt(K) && + "We are expecting that K is always less than all-ones for ShSVT"); + + // If the divisor is 1 the result can be constant-folded. + if (D.isOneValue()) { + // Set P and K amount to a bogus values so we can try to splat them. + P = 0; + K = -1; + assert(Q.isAllOnesValue() && + "Expecting all-ones comparison for one divisor"); + } - // The fold is invalid when D0 == 1. - // This is reachable because visitSetCC happens before visitREM. - if (D0.isOneValue()) + PAmts.push_back(DAG.getConstant(P, DL, SVT)); + KAmts.push_back( + DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT)); + QAmts.push_back(DAG.getConstant(Q, DL, SVT)); + return true; + }; + + SDValue N = REMNode.getOperand(0); + SDValue D = REMNode.getOperand(1); + + // Collect the values from each element. + if (!ISD::matchUnaryPredicate(D, BuildUREMPattern)) return SDValue(); - // P = inv(D0, 2^W) - // 2^W requires W + 1 bits, so we have to extend and then truncate. - unsigned W = D.getBitWidth(); - APInt P = D0.zext(W + 1) - .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) - .trunc(W); - assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable - assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check."); + // If this is a urem by a one, avoid the fold since it can be constant-folded. + if (AllDivisorsAreOnes) + return SDValue(); - // Q = floor((2^W - 1) / D) - APInt Q = APInt::getAllOnesValue(W).udiv(D); + // If this is a urem by a powers-of-two, avoid the fold since it can be + // best implemented as a bit test. + if (AllDivisorsArePowerOfTwo) + return SDValue(); - SelectionDAG &DAG = DCI.DAG; + SDValue PVal, KVal, QVal; + if (VT.isVector()) { + if (HadOneDivisor) { + // Try to turn PAmts into a splat, since we don't care about the values + // that are currently '0'. If we can't, just keep '0'`s. + turnVectorIntoSplatVector(PAmts, isNullConstant); + // Try to turn KAmts into a splat, since we don't care about the values + // that are currently '-1'. If we can't, change them to '0'`s. + turnVectorIntoSplatVector(KAmts, isAllOnesConstant, + DAG.getConstant(0, DL, ShSVT)); + } + + PVal = DAG.getBuildVector(VT, DL, PAmts); + KVal = DAG.getBuildVector(ShVT, DL, KAmts); + QVal = DAG.getBuildVector(VT, DL, QAmts); + } else { + PVal = PAmts[0]; + KVal = KAmts[0]; + QVal = QAmts[0]; + } - SDValue PVal = DAG.getConstant(P, DL, VT); - SDValue QVal = DAG.getConstant(Q, DL, VT); // (mul N, P) - SDValue Op1 = DAG.getNode(ISD::MUL, DL, VT, REMNode->getOperand(0), PVal); - Created.push_back(Op1.getNode()); + SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal); + Created.push_back(Op0.getNode()); - // Rotate right only if D was even. - if (DivisorIsEven) { + // Rotate right only if any divisor was even. We avoid rotates for all-odd + // divisors as a performance improvement, since rotating by 0 is a no-op. + if (HadEvenDivisor) { // We need ROTR to do this. if (!isOperationLegalOrCustom(ISD::ROTR, VT)) return SDValue(); - SDValue ShAmt = - DAG.getConstant(K, DL, getShiftAmountTy(VT, DAG.getDataLayout())); SDNodeFlags Flags; Flags.setExact(true); // UREM: (rotr (mul N, P), K) - Op1 = DAG.getNode(ISD::ROTR, DL, VT, Op1, ShAmt, Flags); - Created.push_back(Op1.getNode()); + Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal, Flags); + Created.push_back(Op0.getNode()); } // UREM: (setule/setugt (rotr (mul N, P), K), Q) - return DAG.getSetCC(DL, SETCCVT, Op1, QVal, + return DAG.getSetCC(DL, SETCCVT, Op0, QVal, ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT)); } |