diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 62499a28dff..59540211d54 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35624,6 +35624,57 @@ static SDValue scalarizeExtEltFP(SDNode *ExtElt, SelectionDAG &DAG) { llvm_unreachable("All opcodes should return within switch"); } +/// Try to convert a vector reduction sequence composed of binops and shuffles +/// into horizontal ops. +static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + assert(ExtElt->getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unexpected caller"); + bool OptForSize = DAG.getMachineFunction().getFunction().hasOptSize(); + if (!Subtarget.hasFastHorizontalOps() && !OptForSize) + return SDValue(); + SDValue Index = ExtElt->getOperand(1); + if (!isNullConstant(Index)) + return SDValue(); + + // TODO: Allow FADD with reduction and/or reassociation and no-signed-zeros. + ISD::NodeType Opc; + SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD}); + if (!Rdx) + return SDValue(); + + EVT VT = ExtElt->getValueType(0); + EVT VecVT = ExtElt->getOperand(0).getValueType(); + if (VecVT.getScalarType() != VT) + return SDValue(); + + unsigned HorizOpcode = Opc == ISD::ADD ? X86ISD::HADD : X86ISD::FHADD; + SDLoc DL(ExtElt); + + // 256-bit horizontal instructions operate on 128-bit chunks rather than + // across the whole vector, so we need an extract + hop preliminary stage. + // This is the only step where the operands of the hop are not the same value. + // TODO: We could extend this to handle 512-bit or even longer vectors. + if (((VecVT == MVT::v16i16 || VecVT == MVT::v8i32) && Subtarget.hasSSSE3()) || + ((VecVT == MVT::v8f32 || VecVT == MVT::v4f64) && Subtarget.hasSSE3())) { + unsigned NumElts = VecVT.getVectorNumElements(); + SDValue Hi = extract128BitVector(Rdx, NumElts / 2, DAG, DL); + SDValue Lo = extract128BitVector(Rdx, 0, DAG, DL); + VecVT = EVT::getVectorVT(*DAG.getContext(), VT, NumElts / 2); + Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Hi, Lo); + } + if (!((VecVT == MVT::v8i16 || VecVT == MVT::v4i32) && Subtarget.hasSSSE3()) && + !((VecVT == MVT::v4f32 || VecVT == MVT::v2f64) && Subtarget.hasSSE3())) + return SDValue(); + + // extract (add (shuf X), X), 0 --> extract (hadd X, X), 0 + assert(Rdx.getValueType() == VecVT && "Unexpected reduction match"); + unsigned ReductionSteps = Log2_32(VecVT.getVectorNumElements()); + for (unsigned i = 0; i != ReductionSteps; ++i) + Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Rdx, Rdx); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index); +} + /// Detect vector gather/scatter index generation and convert it from being a /// bunch of shuffles and extracts into a somewhat faster sequence. /// For i686, the best sequence is apparently storing the value and loading @@ -35710,6 +35761,9 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget)) return MinMax; + if (SDValue V = combineReductionToHorizontal(N, DAG, Subtarget)) + return V; + if (SDValue V = scalarizeExtEltFP(N, DAG)) return V; |