diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 0be5460fa34..f777e562898 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35065,6 +35065,32 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi); } +/// Convert vector increment or decrement to sub/add with an all-ones constant: +/// add X, <1, 1...> --> sub X, <-1, -1...> +/// sub X, <1, 1...> --> add X, <-1, -1...> +/// The all-ones vector constant can be materialized using a pcmpeq instruction +/// that is commonly recognized as an idiom (has no register dependency), so +/// that's better/smaller than loading a splat 1 constant. +static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB && + "Unexpected opcode for increment/decrement transform"); + + // Pseudo-legality check: getOnesVector() expects one of these types, so bail + // out and wait for legalization if we have an unsupported vector length. + EVT VT = N->getValueType(0); + if (!VT.is128BitVector() && !VT.is256BitVector() && !VT.is512BitVector()) + return SDValue(); + + SDNode *N1 = N->getOperand(1).getNode(); + APInt SplatVal; + if (!ISD::isConstantSplatVector(N1, SplatVal) || !SplatVal.isOneValue()) + return SDValue(); + + SDValue AllOnesVec = getOnesVector(VT, DAG, SDLoc(N)); + unsigned NewOpcode = N->getOpcode() == ISD::ADD ? ISD::SUB : ISD::ADD; + return DAG.getNode(NewOpcode, SDLoc(N), VT, N->getOperand(0), AllOnesVec); +} + static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { const SDNodeFlags Flags = N->getFlags(); @@ -35084,6 +35110,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, isHorizontalBinOp(Op0, Op1, true)) return DAG.getNode(X86ISD::HADD, SDLoc(N), VT, Op0, Op1); + if (SDValue V = combineIncDecVector(N, DAG)) + return V; + return combineAddOrSubToADCOrSBB(N, DAG); } @@ -35117,6 +35146,9 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG, isHorizontalBinOp(Op0, Op1, false)) return DAG.getNode(X86ISD::HSUB, SDLoc(N), VT, Op0, Op1); + if (SDValue V = combineIncDecVector(N, DAG)) + return V; + return combineAddOrSubToADCOrSBB(N, DAG); } |