diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index f2a487bad0f..ab4578eeeb1 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -4585,6 +4585,20 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, return nullptr; } +// if Left + Right == Sum (constant or constant splat vector) +static bool sumMatchConstant(SDValue Left, SDValue Right, unsigned Sum, + SelectionDAG &DAG, const SDLoc &DL) { + EVT ShiftVT = Left.getValueType(); + if (ShiftVT != Right.getValueType()) return false; + + SDValue ShiftSum = DAG.FoldConstantArithmetic(ISD::ADD, DL, ShiftVT, + Left.getNode(), Right.getNode()); + if (!ShiftSum) return false; + + ConstantSDNode *CSum = isConstOrConstSplat(ShiftSum); + return CSum && CSum->getZExtValue() == Sum; +} + // MatchRotate - Handle an 'or' of two operands. If this is one of the many // idioms for rotate, and if the target supports rotation instructions, generate // a rot[lr]. @@ -4630,30 +4644,24 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) - if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) { - uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue(); - uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue(); - if ((LShVal + RShVal) != EltSizeInBits) - return nullptr; - + if (sumMatchConstant(LHSShiftAmt, RHSShiftAmt, EltSizeInBits, DAG, DL)) { SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt); // If there is an AND of either shifted operand, apply it to the result. if (LHSMask.getNode() || RHSMask.getNode()) { - SDValue Mask = DAG.getAllOnesConstant(DL, VT); + SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); + SDValue Mask = AllOnes; if (LHSMask.getNode()) { - APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal); + SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt); Mask = DAG.getNode(ISD::AND, DL, VT, Mask, - DAG.getNode(ISD::OR, DL, VT, LHSMask, - DAG.getConstant(RHSBits, DL, VT))); + DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits)); } if (RHSMask.getNode()) { - APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal); + SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt); Mask = DAG.getNode(ISD::AND, DL, VT, Mask, - DAG.getNode(ISD::OR, DL, VT, RHSMask, - DAG.getConstant(LHSBits, DL, VT))); + DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits)); } Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask); |

