diff options
| author | Alexey Bataev <a.bataev@hotmail.com> | 2017-09-08 13:49:36 +0000 |
|---|---|---|
| committer | Alexey Bataev <a.bataev@hotmail.com> | 2017-09-08 13:49:36 +0000 |
| commit | 6dd29fccb881cd3f766effe22a0732d7f02b37fe (patch) | |
| tree | 0b6aaca1b921fa53193282087296a8fe9bf54d9f /llvm/lib | |
| parent | 46dfb7a39d7eb7e7b2439494ad18e31056aa6d09 (diff) | |
| download | bcm5719-llvm-6dd29fccb881cd3f766effe22a0732d7f02b37fe.tar.gz bcm5719-llvm-6dd29fccb881cd3f766effe22a0732d7f02b37fe.zip | |
[SLP] Support for horizontal min/max reduction.
The SLP vectorizer already supports horizontal reductions for Add/FAdd binary
operations. This patch adds support for horizontal min/max reductions.
The function getReductionCost() is split into getArithmeticReductionCost() for
binary operation reductions and getMinMaxReductionCost() for min/max
reductions.
This patch fixes PR26956.
Differential revision: https://reviews.llvm.org/D27846
llvm-svn: 312791
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Analysis/CostModel.cpp | 155 | ||||
| -rw-r--r-- | llvm/lib/Analysis/TargetTransformInfo.cpp | 9 | ||||
| -rw-r--r-- | llvm/lib/Target/X86/X86TargetTransformInfo.cpp | 146 | ||||
| -rw-r--r-- | llvm/lib/Target/X86/X86TargetTransformInfo.h | 3 | ||||
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 292 |
5 files changed, 507 insertions, 98 deletions
diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp index 071e23e90ff..47513f3c387 100644 --- a/llvm/lib/Analysis/CostModel.cpp +++ b/llvm/lib/Analysis/CostModel.cpp @@ -186,26 +186,56 @@ static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, } namespace { +/// Kind of the reduction data. +enum ReductionKind { + RK_None, /// Not a reduction. + RK_Arithmetic, /// Binary reduction data. + RK_MinMax, /// Min/max reduction data. + RK_UnsignedMinMax, /// Unsigned min/max reduction data. +}; /// Contains opcode + LHS/RHS parts of the reduction operations. struct ReductionData { - explicit ReductionData() = default; - ReductionData(unsigned Opcode, Value *LHS, Value *RHS) - : Opcode(Opcode), LHS(LHS), RHS(RHS) {} + ReductionData() = delete; + ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) { + assert(Kind != RK_None && "expected binary or min/max reduction only."); + } unsigned Opcode = 0; Value *LHS = nullptr; Value *RHS = nullptr; + ReductionKind Kind = RK_None; + bool hasSameData(ReductionData &RD) const { + return Kind == RD.Kind && Opcode == RD.Opcode; + } }; } // namespace static Optional<ReductionData> getReductionData(Instruction *I) { Value *L, *R; if (m_BinOp(m_Value(L), m_Value(R)).match(I)) - return ReductionData(I->getOpcode(), L, R); + return ReductionData(RK_Arithmetic, I->getOpcode(), L, R); + if (auto *SI = dyn_cast<SelectInst>(I)) { + if (m_SMin(m_Value(L), m_Value(R)).match(SI) || + m_SMax(m_Value(L), m_Value(R)).match(SI) || + m_OrdFMin(m_Value(L), m_Value(R)).match(SI) || + m_OrdFMax(m_Value(L), m_Value(R)).match(SI) || + m_UnordFMin(m_Value(L), m_Value(R)).match(SI) || + m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) { + auto *CI = cast<CmpInst>(SI->getCondition()); + return ReductionData(RK_MinMax, CI->getOpcode(), L, R); + } + if (m_UMin(m_Value(L), m_Value(R)).match(SI) || + m_UMax(m_Value(L), m_Value(R)).match(SI)) { + auto 
*CI = cast<CmpInst>(SI->getCondition()); + return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R); + } + } return llvm::None; } -static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, - unsigned NumLevels) { +static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, + unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> @@ -213,24 +243,24 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 if (!I) - return false; + return RK_None; assert(I->getType()->isVectorTy() && "Expecting a vector type"); Optional<ReductionData> RD = getReductionData(I); if (!RD) - return false; + return RK_None; ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS); if (!LS && Level) - return false; + return RK_None; ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS); if (!RS && Level) - return false; + return RK_None; // On level 0 we can omit one shufflevector instruction. if (!Level && !RS && !LS) - return false; + return RK_None; // Shuffle inputs must match. Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr; @@ -239,7 +269,7 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, if (NextLevelOpR && NextLevelOpL) { // If we have two shuffles their operands must match. 
if (NextLevelOpL != NextLevelOpR) - return false; + return RK_None; NextLevelOp = NextLevelOpL; } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) { @@ -250,45 +280,47 @@ static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, // %NextLevelOpL = shufflevector %R, <1, undef ...> // %BinOp = fadd %NextLevelOpL, %R if (NextLevelOpL && NextLevelOpL != RD->RHS) - return false; + return RK_None; else if (NextLevelOpR && NextLevelOpR != RD->LHS) - return false; + return RK_None; NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS; - } else - return false; + } else { + return RK_None; + } // Check that the next levels binary operation exists and matches with the // current one. if (Level + 1 != NumLevels) { Optional<ReductionData> NextLevelRD = getReductionData(cast<Instruction>(NextLevelOp)); - if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode) - return false; + if (!NextLevelRD || !RD->hasSameData(*NextLevelRD)) + return RK_None; } // Shuffle mask for pairwise operation must match. if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) - return false; + return RK_None; } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) - return false; - } else - return false; + return RK_None; + } else { + return RK_None; + } if (++Level == NumLevels) - return true; + return RD->Kind; // Match next level. return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level, NumLevels); } -static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, - unsigned &Opcode, Type *&Ty) { +static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, Type *&Ty) { if (!EnableReduxCost) - return false; + return RK_None; // Need to extract the first element. 
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); @@ -296,19 +328,19 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, if (CI) Idx = CI->getZExtValue(); if (Idx != 0) - return false; + return RK_None; auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); if (!RdxStart) - return false; + return RK_None; Optional<ReductionData> RD = getReductionData(RdxStart); if (!RD) - return false; + return RK_None; Type *VecTy = RdxStart->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) - return false; + return RK_None; // We look for a sequence of shuffle,shuffle,add triples like the following // that builds a pairwise reduction tree. @@ -328,13 +360,14 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 // %r = extractelement <4 x float> %bin.rdx8, i32 0 - if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) - return false; + if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) == + RK_None) + return RK_None; Opcode = RD->Opcode; Ty = VecTy; - return true; + return RD->Kind; } static std::pair<Value *, ShuffleVectorInst *> @@ -348,10 +381,11 @@ getShuffleAndOtherOprd(Value *L, Value *R) { return std::make_pair(L, S); } -static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, - unsigned &Opcode, Type *&Ty) { +static ReductionKind +matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, + unsigned &Opcode, Type *&Ty) { if (!EnableReduxCost) - return false; + return RK_None; // Need to extract the first element. 
ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); @@ -359,19 +393,19 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, if (CI) Idx = CI->getZExtValue(); if (Idx != 0) - return false; + return RK_None; auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); if (!RdxStart) - return false; + return RK_None; Optional<ReductionData> RD = getReductionData(RdxStart); if (!RD) - return false; + return RK_None; Type *VecTy = ReduxRoot->getOperand(0)->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) - return false; + return RK_None; // We look for a sequence of shuffles and adds like the following matching one // fadd, shuffle vector pair at a time. @@ -391,10 +425,10 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, while (NumVecElemsRemain - 1) { // Check for the right reduction operation. if (!RdxOp) - return false; + return RK_None; Optional<ReductionData> RDLevel = getReductionData(RdxOp); - if (!RDLevel || RDLevel->Opcode != RD->Opcode) - return false; + if (!RDLevel || !RDLevel->hasSameData(*RD)) + return RK_None; Value *NextRdxOp; ShuffleVectorInst *Shuffle; @@ -403,9 +437,9 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, // Check the current reduction operation and the shuffle use the same value. if (Shuffle == nullptr) - return false; + return RK_None; if (Shuffle->getOperand(0) != NextRdxOp) - return false; + return RK_None; // Check that shuffle masks matches. 
for (unsigned j = 0; j != MaskStart; ++j) @@ -415,7 +449,7 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); if (ShuffleMask != Mask) - return false; + return RK_None; RdxOp = dyn_cast<Instruction>(NextRdxOp); NumVecElemsRemain /= 2; @@ -424,7 +458,7 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, Opcode = RD->Opcode; Ty = VecTy; - return true; + return RD->Kind; } unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { @@ -519,13 +553,36 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { unsigned ReduxOpCode; Type *ReduxType; - if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + case RK_Arithmetic: return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, /*IsPairwiseForm=*/false); + case RK_MinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/false); + case RK_UnsignedMinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/false, /*IsUnsigned=*/true); + case RK_None: + break; } - if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + + switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + case RK_Arithmetic: return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, /*IsPairwiseForm=*/true); + case RK_MinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/false); + case RK_UnsignedMinMax: + return TTI->getMinMaxReductionCost( + ReduxType, CmpInst::makeCmpResultType(ReduxType), + /*IsPairwiseForm=*/true, /*IsUnsigned=*/true); + case RK_None: + break; } return TTI->getVectorInstrCost(I->getOpcode(), diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp 
b/llvm/lib/Analysis/TargetTransformInfo.cpp index e09138168c9..8673b1b55d9 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -484,6 +484,15 @@ int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty, return Cost; } +int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy, + bool IsPairwiseForm, + bool IsUnsigned) const { + int Cost = + TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + unsigned TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const { return TTIImpl->getCostOfKeepingLiveOverCall(Tys); diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 871a38d0014..79f192ce062 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1999,6 +1999,152 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy, return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise); } +int X86TTIImpl::getMinMaxReductionCost(Type *ValTy, Type *CondTy, + bool IsPairwise, bool IsUnsigned) { + std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy); + + MVT MTy = LT.second; + + int ISD; + if (ValTy->isIntOrIntVectorTy()) { + ISD = IsUnsigned ? ISD::UMIN : ISD::SMIN; + } else { + assert(ValTy->isFPOrFPVectorTy() && + "Expected float point or integer vector type."); + ISD = ISD::FMINNUM; + } + + // We use the Intel Architecture Code Analyzer(IACA) to measure the throughput + // and make it as the cost. 
+ + static const CostTblEntry SSE42CostTblPairWise[] = { + {ISD::FMINNUM, MVT::v2f64, 3}, + {ISD::FMINNUM, MVT::v4f32, 2}, + {ISD::SMIN, MVT::v2i64, 7}, // The data reported by the IACA is "6.8" + {ISD::UMIN, MVT::v2i64, 8}, // The data reported by the IACA is "8.6" + {ISD::SMIN, MVT::v4i32, 1}, // The data reported by the IACA is "1.5" + {ISD::UMIN, MVT::v4i32, 2}, // The data reported by the IACA is "1.8" + {ISD::SMIN, MVT::v8i16, 2}, + {ISD::UMIN, MVT::v8i16, 2}, + }; + + static const CostTblEntry AVX1CostTblPairWise[] = { + {ISD::FMINNUM, MVT::v4f32, 1}, + {ISD::FMINNUM, MVT::v4f64, 1}, + {ISD::FMINNUM, MVT::v8f32, 2}, + {ISD::SMIN, MVT::v2i64, 3}, + {ISD::UMIN, MVT::v2i64, 3}, + {ISD::SMIN, MVT::v4i32, 1}, + {ISD::UMIN, MVT::v4i32, 1}, + {ISD::SMIN, MVT::v8i16, 1}, + {ISD::UMIN, MVT::v8i16, 1}, + {ISD::SMIN, MVT::v8i32, 3}, + {ISD::UMIN, MVT::v8i32, 3}, + }; + + static const CostTblEntry AVX2CostTblPairWise[] = { + {ISD::SMIN, MVT::v4i64, 2}, + {ISD::UMIN, MVT::v4i64, 2}, + {ISD::SMIN, MVT::v8i32, 1}, + {ISD::UMIN, MVT::v8i32, 1}, + {ISD::SMIN, MVT::v16i16, 1}, + {ISD::UMIN, MVT::v16i16, 1}, + {ISD::SMIN, MVT::v32i8, 2}, + {ISD::UMIN, MVT::v32i8, 2}, + }; + + static const CostTblEntry AVX512CostTblPairWise[] = { + {ISD::FMINNUM, MVT::v8f64, 1}, + {ISD::FMINNUM, MVT::v16f32, 2}, + {ISD::SMIN, MVT::v8i64, 2}, + {ISD::UMIN, MVT::v8i64, 2}, + {ISD::SMIN, MVT::v16i32, 1}, + {ISD::UMIN, MVT::v16i32, 1}, + }; + + static const CostTblEntry SSE42CostTblNoPairWise[] = { + {ISD::FMINNUM, MVT::v2f64, 3}, + {ISD::FMINNUM, MVT::v4f32, 3}, + {ISD::SMIN, MVT::v2i64, 7}, // The data reported by the IACA is "6.8" + {ISD::UMIN, MVT::v2i64, 9}, // The data reported by the IACA is "8.6" + {ISD::SMIN, MVT::v4i32, 1}, // The data reported by the IACA is "1.5" + {ISD::UMIN, MVT::v4i32, 2}, // The data reported by the IACA is "1.8" + {ISD::SMIN, MVT::v8i16, 1}, // The data reported by the IACA is "1.5" + {ISD::UMIN, MVT::v8i16, 2}, // The data reported by the IACA is "1.8" + }; + + 
static const CostTblEntry AVX1CostTblNoPairWise[] = { + {ISD::FMINNUM, MVT::v4f32, 1}, + {ISD::FMINNUM, MVT::v4f64, 1}, + {ISD::FMINNUM, MVT::v8f32, 1}, + {ISD::SMIN, MVT::v2i64, 3}, + {ISD::UMIN, MVT::v2i64, 3}, + {ISD::SMIN, MVT::v4i32, 1}, + {ISD::UMIN, MVT::v4i32, 1}, + {ISD::SMIN, MVT::v8i16, 1}, + {ISD::UMIN, MVT::v8i16, 1}, + {ISD::SMIN, MVT::v8i32, 2}, + {ISD::UMIN, MVT::v8i32, 2}, + }; + + static const CostTblEntry AVX2CostTblNoPairWise[] = { + {ISD::SMIN, MVT::v4i64, 1}, + {ISD::UMIN, MVT::v4i64, 1}, + {ISD::SMIN, MVT::v8i32, 1}, + {ISD::UMIN, MVT::v8i32, 1}, + {ISD::SMIN, MVT::v16i16, 1}, + {ISD::UMIN, MVT::v16i16, 1}, + {ISD::SMIN, MVT::v32i8, 1}, + {ISD::UMIN, MVT::v32i8, 1}, + }; + + static const CostTblEntry AVX512CostTblNoPairWise[] = { + {ISD::FMINNUM, MVT::v8f64, 1}, + {ISD::FMINNUM, MVT::v16f32, 2}, + {ISD::SMIN, MVT::v8i64, 1}, + {ISD::UMIN, MVT::v8i64, 1}, + {ISD::SMIN, MVT::v16i32, 1}, + {ISD::UMIN, MVT::v16i32, 1}, + }; + + if (IsPairwise) { + if (ST->hasAVX512()) + if (const auto *Entry = CostTableLookup(AVX512CostTblPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + + if (ST->hasAVX2()) + if (const auto *Entry = CostTableLookup(AVX2CostTblPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + + if (ST->hasAVX()) + if (const auto *Entry = CostTableLookup(AVX1CostTblPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + + if (ST->hasSSE42()) + if (const auto *Entry = CostTableLookup(SSE42CostTblPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + } else { + if (ST->hasAVX512()) + if (const auto *Entry = + CostTableLookup(AVX512CostTblNoPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + + if (ST->hasAVX2()) + if (const auto *Entry = CostTableLookup(AVX2CostTblNoPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + + if (ST->hasAVX()) + if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + + if (ST->hasSSE42()) + if (const auto *Entry = 
CostTableLookup(SSE42CostTblNoPairWise, ISD, MTy)) + return LT.first * Entry->Cost; + } + + return BaseT::getMinMaxReductionCost(ValTy, CondTy, IsPairwise, IsUnsigned); +} + /// \brief Calculate the cost of materializing a 64-bit value. This helper /// method might only calculate a fraction of a larger immediate. Therefore it /// is valid to return a cost of ZERO. diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h index a8edc46ed57..a7f500dc507 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -96,6 +96,9 @@ public: int getArithmeticReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm); + int getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwiseForm, + bool IsUnsigned); + int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices, unsigned Alignment, unsigned AddressSpace); diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index b147445d716..53b1d871fa3 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -4627,11 +4627,17 @@ class HorizontalReduction { // Use map vector to make stable output. MapVector<Instruction *, Value *> ExtraArgs; + /// Kind of the reduction data. + enum ReductionKind { + RK_None, /// Not a reduction. + RK_Arithmetic, /// Binary reduction data. + RK_Min, /// Minimum reduction data. + RK_UMin, /// Unsigned minimum reduction data. + RK_Max, /// Maximum reduction data. + RK_UMax, /// Unsigned maximum reduction data. + }; /// Contains info about operation, like its opcode, left and right operands. - struct OperationData { - /// true if the operation is a reduced value, false if reduction operation. - bool IsReducedValue = false; - + class OperationData { /// Opcode of the instruction. 
unsigned Opcode = 0; @@ -4640,12 +4646,21 @@ class HorizontalReduction { /// Right operand of the reduction operation. Value *RHS = nullptr; + /// Kind of the reduction operation. + ReductionKind Kind = RK_None; + /// True if float point min/max reduction has no NaNs. + bool NoNaN = false; /// Checks if the reduction operation can be vectorized. bool isVectorizable() const { return LHS && RHS && - // We currently only support adds. - (Opcode == Instruction::Add || Opcode == Instruction::FAdd); + // We currently only support adds && min/max reductions. + ((Kind == RK_Arithmetic && + (Opcode == Instruction::Add || Opcode == Instruction::FAdd)) || + ((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && + (Kind == RK_Min || Kind == RK_Max)) || + (Opcode == Instruction::ICmp && + (Kind == RK_UMin || Kind == RK_UMax))); } public: @@ -4653,43 +4668,90 @@ class HorizontalReduction { /// Construction for reduced values. They are identified by opcode only and /// don't have associated LHS/RHS values. - explicit OperationData(Value *V) : IsReducedValue(true) { + explicit OperationData(Value *V) : Kind(RK_None) { if (auto *I = dyn_cast<Instruction>(V)) Opcode = I->getOpcode(); } - /// Constructor for binary reduction operations with opcode and its left and + /// Constructor for reduction operations with opcode and its left and /// right operands. - OperationData(unsigned Opcode, Value *LHS, Value *RHS) - : Opcode(Opcode), LHS(LHS), RHS(RHS) {} - + OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind, + bool NoNaN = false) + : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), NoNaN(NoNaN) { + assert(Kind != RK_None && "One of the reduction operations is expected."); + } explicit operator bool() const { return Opcode; } /// Get the index of the first operand. 
unsigned getFirstOperandIndex() const { assert(!!*this && "The opcode is not set."); + switch (Kind) { + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + return 1; + case RK_Arithmetic: + case RK_None: + break; + } return 0; } /// Total number of operands in the reduction operation. unsigned getNumberOfOperands() const { - assert(!IsReducedValue && !!*this && LHS && RHS && + assert(Kind != RK_None && !!*this && LHS && RHS && "Expected reduction operation."); - return 2; + switch (Kind) { + case RK_Arithmetic: + return 2; + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + return 3; + case RK_None: + llvm_unreachable("Reduction kind is not set"); + } } /// Expected number of uses for reduction operations/reduced values. unsigned getRequiredNumberOfUses() const { - assert(!IsReducedValue && !!*this && LHS && RHS && + assert(Kind != RK_None && !!*this && LHS && RHS && "Expected reduction operation."); - return 1; + switch (Kind) { + case RK_Arithmetic: + return 1; + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + return 2; + case RK_None: + llvm_unreachable("Reduction kind is not set"); + } } /// Checks if instruction is associative and can be vectorized. bool isAssociative(Instruction *I) const { - assert(!IsReducedValue && *this && LHS && RHS && + assert(Kind != RK_None && *this && LHS && RHS && "Expected reduction operation."); - return I->isAssociative(); + switch (Kind) { + case RK_Arithmetic: + return I->isAssociative(); + case RK_Min: + case RK_Max: + return Opcode == Instruction::ICmp || + cast<Instruction>(I->getOperand(0))->hasUnsafeAlgebra(); + case RK_UMin: + case RK_UMax: + assert(Opcode == Instruction::ICmp && + "Only integer compare operation is expected."); + return true; + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); } /// Checks if the reduction operation can be vectorized. 
@@ -4700,18 +4762,17 @@ class HorizontalReduction { /// Checks if two operation data are both a reduction op or both a reduced /// value. bool operator==(const OperationData &OD) { - assert(((IsReducedValue != OD.IsReducedValue) || - ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && + assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && "One of the comparing operations is incorrect."); - return this == &OD || - (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode); + return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode); } bool operator!=(const OperationData &OD) { return !(*this == OD); } void clear() { - IsReducedValue = false; Opcode = 0; LHS = nullptr; RHS = nullptr; + Kind = RK_None; + NoNaN = false; } /// Get the opcode of the reduction operation. @@ -4720,16 +4781,81 @@ class HorizontalReduction { return Opcode; } + /// Get kind of reduction data. + ReductionKind getKind() const { return Kind; } Value *getLHS() const { return LHS; } Value *getRHS() const { return RHS; } + Type *getConditionType() const { + switch (Kind) { + case RK_Arithmetic: + return nullptr; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: + return CmpInst::makeCmpResultType(LHS->getType()); + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); + } /// Creates reduction operation with the current opcode. Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const { - assert(!IsReducedValue && - (Opcode == Instruction::FAdd || Opcode == Instruction::Add) && - "Expected add|fadd reduction operation."); - return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, - Name); + assert(isVectorizable() && + "Expected add|fadd or min/max reduction operation."); + Value *Cmp; + switch (Kind) { + case RK_Arithmetic: + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, + Name); + case RK_Min: + Cmp = Opcode == Instruction::ICmp ? 
Builder.CreateICmpSLT(LHS, RHS) + : Builder.CreateFCmpOLT(LHS, RHS); + break; + case RK_Max: + Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS) + : Builder.CreateFCmpOGT(LHS, RHS); + break; + case RK_UMin: + assert(Opcode == Instruction::ICmp && "Expected integer types."); + Cmp = Builder.CreateICmpULT(LHS, RHS); + break; + case RK_UMax: + assert(Opcode == Instruction::ICmp && "Expected integer types."); + Cmp = Builder.CreateICmpUGT(LHS, RHS); + break; + case RK_None: + llvm_unreachable("Unknown reduction operation."); + } + return Builder.CreateSelect(Cmp, LHS, RHS, Name); + } + TargetTransformInfo::ReductionFlags getFlags() const { + TargetTransformInfo::ReductionFlags Flags; + Flags.NoNaN = NoNaN; + switch (Kind) { + case RK_Arithmetic: + break; + case RK_Min: + Flags.IsSigned = Opcode == Instruction::ICmp; + Flags.IsMaxOp = false; + break; + case RK_Max: + Flags.IsSigned = Opcode == Instruction::ICmp; + Flags.IsMaxOp = true; + break; + case RK_UMin: + Flags.IsSigned = false; + Flags.IsMaxOp = false; + break; + case RK_UMax: + Flags.IsSigned = false; + Flags.IsMaxOp = true; + break; + case RK_None: + llvm_unreachable("Reduction kind is not set"); + } + return Flags; } }; @@ -4771,8 +4897,32 @@ class HorizontalReduction { Value *LHS; Value *RHS; - if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) - return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS); + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) { + return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS, + RK_Arithmetic); + } + if (auto *Select = dyn_cast<SelectInst>(V)) { + // Look for a min/max pattern. 
+ if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin); + } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_Min); + } else if (m_OrdFMin(m_Value(LHS), m_Value(RHS)).match(Select) || + m_UnordFMin(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData( + Instruction::FCmp, LHS, RHS, RK_Min, + cast<Instruction>(Select->getCondition())->hasNoNaNs()); + } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax); + } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_Max); + } else if (m_OrdFMax(m_Value(LHS), m_Value(RHS)).match(Select) || + m_UnordFMax(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData( + Instruction::FCmp, LHS, RHS, RK_Max, + cast<Instruction>(Select->getCondition())->hasNoNaNs()); + } + } return OperationData(V); } @@ -4965,8 +5115,9 @@ public: if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); OperationData VectReductionData(ReductionData.getOpcode(), - VectorizedTree, ReducedSubTree); - VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx"); + VectorizedTree, ReducedSubTree, + ReductionData.getKind()); + VectorizedTree = VectReductionData.createOp(Builder, "op.rdx"); propagateIRFlags(VectorizedTree, ReductionOps); } else VectorizedTree = ReducedSubTree; @@ -4980,7 +5131,8 @@ public: auto *I = cast<Instruction>(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); OperationData VectReductionData(ReductionData.getOpcode(), - VectorizedTree, I); + VectorizedTree, I, + ReductionData.getKind()); VectorizedTree = VectReductionData.createOp(Builder); propagateIRFlags(VectorizedTree, ReductionOps); } @@ -4991,8 +5143,9 @@ public: for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); OperationData 
VectReductionData(ReductionData.getOpcode(), - VectorizedTree, Pair.first); - VectorizedTree = VectReductionData.createOp(Builder, "bin.extra"); + VectorizedTree, Pair.first, + ReductionData.getKind()); + VectorizedTree = VectReductionData.createOp(Builder, "op.extra"); propagateIRFlags(VectorizedTree, I); } } @@ -5013,19 +5166,58 @@ private: Type *ScalarTy = FirstReducedVal->getType(); Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); - int PairwiseRdxCost = - TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, - /*IsPairwiseForm=*/true); - int SplittingRdxCost = - TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, - /*IsPairwiseForm=*/false); + int PairwiseRdxCost; + int SplittingRdxCost; + bool IsUnsigned = true; + switch (ReductionData.getKind()) { + case RK_Arithmetic: + PairwiseRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); + SplittingRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); + break; + case RK_Min: + case RK_Max: + IsUnsigned = false; + case RK_UMin: + case RK_UMax: { + Type *VecCondTy = CmpInst::makeCmpResultType(VecTy); + PairwiseRdxCost = + TTI->getMinMaxReductionCost(VecTy, VecCondTy, + /*IsPairwiseForm=*/true, IsUnsigned); + SplittingRdxCost = + TTI->getMinMaxReductionCost(VecTy, VecCondTy, + /*IsPairwiseForm=*/false, IsUnsigned); + break; + } + case RK_None: + llvm_unreachable("Expected arithmetic or min/max reduction operation"); + } IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? 
PairwiseRdxCost : SplittingRdxCost; - int ScalarReduxCost = - (ReduxWidth - 1) * - TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); + int ScalarReduxCost; + switch (ReductionData.getKind()) { + case RK_Arithmetic: + ScalarReduxCost = + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); + break; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: + ScalarReduxCost = + TTI->getCmpSelInstrCost(ReductionData.getOpcode(), ScalarTy) + + TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, + CmpInst::makeCmpResultType(ScalarTy)); + break; + case RK_None: + llvm_unreachable("Expected arithmetic or min/max reduction operation"); + } + ScalarReduxCost *= (ReduxWidth - 1); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -5047,7 +5239,7 @@ private: if (!IsPairwiseReduction) return createSimpleTargetReduction( Builder, TTI, ReductionData.getOpcode(), VectorizedValue, - TargetTransformInfo::ReductionFlags(), RedOps); + ReductionData.getFlags(), RedOps); Value *TmpVec = VectorizedValue; for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) { @@ -5062,8 +5254,8 @@ private: TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf, - RightShuf); - TmpVec = VectReductionData.createOp(Builder, "bin.rdx"); + RightShuf, ReductionData.getKind()); + TmpVec = VectReductionData.createOp(Builder, "op.rdx"); propagateIRFlags(TmpVec, RedOps); } @@ -5224,9 +5416,11 @@ static bool tryToVectorizeHorReductionOrInstOperands( auto *Inst = dyn_cast<Instruction>(V); if (!Inst) continue; - if (auto *BI = dyn_cast<BinaryOperator>(Inst)) { + auto *BI = dyn_cast<BinaryOperator>(Inst); + auto *SI = dyn_cast<SelectInst>(Inst); + if (BI || SI) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, BI)) { + if (HorRdx.matchAssociativeReduction(P, Inst)) { if (HorRdx.tryToReduce(R, TTI)) { 
Res = true; // Set P to nullptr to avoid re-analysis of phi node in @@ -5235,7 +5429,7 @@ static bool tryToVectorizeHorReductionOrInstOperands( continue; } } - if (P) { + if (P && BI) { Inst = dyn_cast<Instruction>(BI->getOperand(0)); if (Inst == P) Inst = dyn_cast<Instruction>(BI->getOperand(1)); |

