diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index ee9d13c5c2c..014e67274b8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -9005,7 +9005,8 @@ void SDNode::intersectFlagsWith(const SDNodeFlags Flags) { SDValue SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, - ArrayRef<ISD::NodeType> CandidateBinOps) { + ArrayRef<ISD::NodeType> CandidateBinOps, + bool AllowPartials) { // The pattern must end in an extract from index 0. if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT || !isNullConstant(Extract->getOperand(1))) @@ -9019,6 +9020,23 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, return Op.getOpcode() == unsigned(BinOp); })) return SDValue(); + unsigned CandidateBinOp = Op.getOpcode(); + + // Matching failed - attempt to see if we did enough stages that a partial + // reduction from a subvector is possible. + auto PartialReduction = [&](SDValue Op, unsigned NumSubElts) { + if (!AllowPartials || !Op) + return SDValue(); + EVT OpVT = Op.getValueType(); + EVT OpSVT = OpVT.getScalarType(); + EVT SubVT = EVT::getVectorVT(*getContext(), OpSVT, NumSubElts); + if (!TLI->isExtractSubvectorCheap(SubVT, OpVT, 0)) + return SDValue(); + BinOp = (ISD::NodeType)CandidateBinOp; + return getNode( + ISD::EXTRACT_SUBVECTOR, SDLoc(Op), SubVT, Op, + getConstant(0, SDLoc(Op), TLI->getVectorIdxTy(getDataLayout()))); + }; // At each stage, we're looking for something that looks like: // %s = shufflevector <8 x i32> %op, <8 x i32> undef, @@ -9030,10 +9048,15 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, // <4,5,6,7,u,u,u,u> // <2,3,u,u,u,u,u,u> // <1,u,u,u,u,u,u,u> - unsigned CandidateBinOp = Op.getOpcode(); + // While a partial reduction match would be: + // <2,3,u,u,u,u,u,u> + // <1,u,u,u,u,u,u,u> + SDValue PrevOp; for (unsigned i = 0; i < Stages; ++i) { + unsigned MaskEnd = (1 << i); + if (Op.getOpcode() != CandidateBinOp) - return SDValue(); + return PartialReduction(PrevOp, MaskEnd); SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); @@ -9049,12 +9072,14 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, // The first operand of the shuffle should be the same as the other operand // of the binop. if (!Shuffle || Shuffle->getOperand(0) != Op) - return SDValue(); + return PartialReduction(PrevOp, MaskEnd); // Verify the shuffle has the expected (at this stage of the pyramid) mask. - for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index) - if (Shuffle->getMaskElt(Index) != MaskEnd + Index) - return SDValue(); + for (int Index = 0; Index < (int)MaskEnd; ++Index) + if (Shuffle->getMaskElt(Index) != (MaskEnd + Index)) + return PartialReduction(PrevOp, MaskEnd); + + PrevOp = Op; } BinOp = (ISD::NodeType)CandidateBinOp; |