diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 37 |
1 files changed, 28 insertions, 9 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 9a9ac690aaa..6f563c2a0ee 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -4943,7 +4943,7 @@ SDValue TargetLowering::buildUREMEqFold(EVT SETCCVT, SDValue REMNode, ISD::CondCode Cond, DAGCombinerInfo &DCI, const SDLoc &DL) const { - SmallVector<SDNode *, 4> Built; + SmallVector<SDNode *, 5> Built; if (SDValue Folded = prepareUREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond, DCI, DL, Built)) { for (SDNode *N : Built) @@ -4978,6 +4978,8 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, if (!isOperationLegalOrCustom(ISD::MUL, VT)) return SDValue(); + bool ComparingWithAllZeros = true; + bool AllComparisonsWithNonZerosAreTautological = true; bool HadTautologicalLanes = false; bool AllLanesAreTautological = true; bool HadEvenDivisor = false; @@ -4993,6 +4995,8 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, const APInt &D = CDiv->getAPIntValue(); const APInt &Cmp = CCmp->getAPIntValue(); + ComparingWithAllZeros &= Cmp.isNullValue(); + // x u% C1` is *always* less than C1. So given `x u% C1 == C2`, // if C2 is not less than C1, the comparison is always false. // But we will only be able to produce the comparison that will give the @@ -5000,12 +5004,6 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, bool TautologicalInvertedLane = D.ule(Cmp); HadTautologicalInvertedLanes |= TautologicalInvertedLane; - // If we are checking that remainder is something smaller than the divisor, - // then this comparison isn't tautological. For now this is not handled, - // other than the comparison that remainder is zero. - if (!Cmp.isNullValue() && !TautologicalInvertedLane) - return false; - // If all lanes are tautological (either all divisors are ones, or divisor // is not greater than the constant we are comparing with), // we will prefer to avoid the fold. @@ -5013,6 +5011,12 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, HadTautologicalLanes |= TautologicalLane; AllLanesAreTautological &= TautologicalLane; + // If we are comparing with non-zero, we need'll need to subtract said + // comparison value from the LHS. But there is no point in doing that if + // every lane where we are comparing with non-zero is tautological.. + if (!Cmp.isNullValue()) + AllComparisonsWithNonZerosAreTautological &= TautologicalLane; + // Decompose D into D0 * 2^K unsigned K = D.countTrailingZeros(); assert((!D.isOneValue() || (K == 0)) && "For divisor '1' we won't rotate."); @@ -5033,8 +5037,15 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check."); - // Q = floor((2^W - 1) / D) - APInt Q = APInt::getAllOnesValue(W).udiv(D); + // Q = floor((2^W - 1) u/ D) + // R = ((2^W - 1) u% D) + APInt Q, R; + APInt::udivrem(APInt::getAllOnesValue(W), D, Q, R); + + // If we are comparing with zero, then that comparison constant is okay, + // else it may need to be one less than that. + if (Cmp.ugt(R)) + Q -= 1; assert(APInt::getAllOnesValue(ShSVT.getSizeInBits()).ugt(K) && "We are expecting that K is always less than all-ones for ShSVT"); @@ -5093,6 +5104,14 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, QVal = QAmts[0]; } + if (!ComparingWithAllZeros && !AllComparisonsWithNonZerosAreTautological) { + if (!isOperationLegalOrCustom(ISD::SUB, VT)) + return SDValue(); // FIXME: Could/should use `ISD::ADD`? + assert(CompTargetNode.getValueType() == N.getValueType() && + "Expecting that the types on LHS and RHS of comparisons match."); + N = DAG.getNode(ISD::SUB, DL, VT, N, CompTargetNode); + } + // (mul N, P) SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal); Created.push_back(Op0.getNode()); |

