diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2018-02-22 18:45:13 +0000 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2018-02-22 18:45:13 +0000 |
commit | be72fe1fda2de77b23b16c9b53186c8151752097 (patch) | |
tree | e570f15f6fca9d76ece2d5902ac23139fa1e388b /llvm/lib | |
parent | df38f155eca55b5123cf9fcc3d18686486ba7048 (diff) | |
download | bcm5719-llvm-be72fe1fda2de77b23b16c9b53186c8151752097.tar.gz bcm5719-llvm-be72fe1fda2de77b23b16c9b53186c8151752097.zip |
[SelectionDAG] Move matchUnaryPredicate/matchBinaryPredicate into SelectionDAGNodes.h
This allows us to improve vector constant matching in more DAG code (backends, TargetLowering etc.).
Differential Revision: https://reviews.llvm.org/D43466
llvm-svn: 325815
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 74 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 46 |
2 files changed, 58 insertions, 62 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 158a020d443..3f99843406a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -920,56 +920,6 @@ static bool isAnyConstantBuildVector(const SDNode *N) { ISD::isBuildVectorOfConstantFPSDNodes(N); } -// Attempt to match a unary predicate against a scalar/splat constant or -// every element of a constant BUILD_VECTOR. -static bool matchUnaryPredicate(SDValue Op, - std::function<bool(ConstantSDNode *)> Match) { - if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) - return Match(Cst); - - if (ISD::BUILD_VECTOR != Op.getOpcode()) - return false; - - EVT SVT = Op.getValueType().getScalarType(); - for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { - auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i)); - if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst)) - return false; - } - return true; -} - -// Attempt to match a binary predicate against a pair of scalar/splat constants -// or every element of a pair of constant BUILD_VECTORs. -static bool matchBinaryPredicate( - SDValue LHS, SDValue RHS, - std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) { - if (LHS.getValueType() != RHS.getValueType()) - return false; - - if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS)) - if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS)) - return Match(LHSCst, RHSCst); - - if (ISD::BUILD_VECTOR != LHS.getOpcode() || - ISD::BUILD_VECTOR != RHS.getOpcode()) - return false; - - EVT SVT = LHS.getValueType().getScalarType(); - for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) { - auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i)); - auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i)); - if (!LHSCst || !RHSCst) - return false; - if (LHSCst->getValueType(0) != SVT || - LHSCst->getValueType(0) != RHSCst->getValueType(0)) - return false; - if (!Match(LHSCst, RHSCst)) - return false; - } - return true; -} - SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, SDValue N1) { EVT VT = N0.getValueType(); @@ -4067,7 +4017,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue()); }; if (N0.getOpcode() == ISD::OR && - matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset)) + ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset)) return N1; // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits. if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) { @@ -4756,7 +4706,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return LHS->getAPIntValue().intersects(RHS->getAPIntValue()); }; if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && - matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) { + ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) { if (SDValue COR = DAG.FoldConstantArithmetic( ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) { SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1); @@ -4991,7 +4941,7 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { ConstantSDNode *RHS) { return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits; }; - if (matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { + if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt); @@ -5704,7 +5654,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { return Val->getAPIntValue().uge(OpSizeInBits); }; - if (matchUnaryPredicate(N1, MatchShiftTooBig)) + if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) return DAG.getUNDEF(VT); // fold (shl x, 0) -> x if (N1C && N1C->isNullValue()) @@ -5739,7 +5689,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).uge(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) return DAG.getConstant(0, SDLoc(N), VT); auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, @@ -5749,7 +5699,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).ult(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDLoc DL(N); EVT ShiftVT = N1.getValueType(); SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); @@ -5925,7 +5875,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { return Val->getAPIntValue().uge(OpSizeInBits); }; - if (matchUnaryPredicate(N1, MatchShiftTooBig)) + if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) return DAG.getUNDEF(VT); // fold (sra x, 0) -> x if (N1C && N1C->isNullValue()) @@ -5960,7 +5910,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).uge(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT)); @@ -5971,7 +5921,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).ult(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum); } @@ -6089,7 +6039,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { return Val->getAPIntValue().uge(OpSizeInBits); }; - if (matchUnaryPredicate(N1, MatchShiftTooBig)) + if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) return DAG.getUNDEF(VT); // fold (srl x, 0) -> x if (N1C && N1C->isNullValue()) @@ -6112,7 +6062,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).uge(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) return DAG.getConstant(0, SDLoc(N), VT); auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, @@ -6122,7 +6072,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).ult(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDLoc DL(N); EVT ShiftVT = N1.getValueType(); SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 11661eb6262..87828a722b8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -263,6 +263,52 @@ bool ISD::allOperandsUndef(const SDNode *N) { return true; } +bool ISD::matchUnaryPredicate(SDValue Op, + std::function<bool(ConstantSDNode *)> Match) { + if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) + return Match(Cst); + + if (ISD::BUILD_VECTOR != Op.getOpcode()) + return false; + + EVT SVT = Op.getValueType().getScalarType(); + for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { + auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i)); + if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst)) + return false; + } + return true; +} + +bool ISD::matchBinaryPredicate( + SDValue LHS, SDValue RHS, + std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) { + if (LHS.getValueType() != RHS.getValueType()) + return false; + + if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS)) + if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS)) + return Match(LHSCst, RHSCst); + + if (ISD::BUILD_VECTOR != LHS.getOpcode() || + ISD::BUILD_VECTOR != RHS.getOpcode()) + return false; + + EVT SVT = LHS.getValueType().getScalarType(); + for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) { + auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i)); + auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i)); + if (!LHSCst || !RHSCst) + return false; + if (LHSCst->getValueType(0) != SVT || + LHSCst->getValueType(0) != RHSCst->getValueType(0)) + return false; + if (!Match(LHSCst, RHSCst)) + return false; + } + return true; +} + ISD::NodeType ISD::getExtForLoadExtType(bool IsFP, ISD::LoadExtType ExtType) { switch (ExtType) { case ISD::EXTLOAD: |