diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 294 |
1 files changed, 49 insertions, 245 deletions
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 03b6a8df33b..0e47dc1a407 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -4653,17 +4653,11 @@ class HorizontalReduction { // Use map vector to make stable output. MapVector<Instruction *, Value *> ExtraArgs; - /// Kind of the reduction data. - enum ReductionKind { - RK_None, /// Not a reduction. - RK_Arithmetic, /// Binary reduction data. - RK_Min, /// Minimum reduction data. - RK_UMin, /// Unsigned minimum reduction data. - RK_Max, /// Maximum reduction data. - RK_UMax, /// Unsigned maximum reduction data. - }; /// Contains info about operation, like its opcode, left and right operands. - class OperationData { + struct OperationData { + /// true if the operation is a reduced value, false if reduction operation. + bool IsReducedValue = false; + /// Opcode of the instruction. unsigned Opcode = 0; @@ -4672,21 +4666,12 @@ class HorizontalReduction { /// Right operand of the reduction operation. Value *RHS = nullptr; - /// Kind of the reduction operation. - ReductionKind Kind = RK_None; - /// True if float point min/max reduction has no NaNs. - bool NoNaN = false; /// Checks if the reduction operation can be vectorized. bool isVectorizable() const { return LHS && RHS && - // We currently only support adds && min/max reductions. - ((Kind == RK_Arithmetic && - (Opcode == Instruction::Add || Opcode == Instruction::FAdd)) || - ((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && - (Kind == RK_Min || Kind == RK_Max)) || - (Opcode == Instruction::ICmp && - (Kind == RK_UMin || Kind == RK_UMax))); + // We currently only support adds. + (Opcode == Instruction::Add || Opcode == Instruction::FAdd); } public: @@ -4694,92 +4679,43 @@ class HorizontalReduction { /// Construction for reduced values. They are identified by opcode only and /// don't have associated LHS/RHS values. - explicit OperationData(Value *V) : Kind(RK_None) { + explicit OperationData(Value *V) : IsReducedValue(true) { if (auto *I = dyn_cast<Instruction>(V)) Opcode = I->getOpcode(); } - /// Constructor for reduction operations with opcode and its left and + /// Constructor for binary reduction operations with opcode and its left and /// right operands. - OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind, - bool NoNaN = false) - : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), NoNaN(NoNaN) { - assert(Kind != RK_None && "One of the reduction operations is expected."); - } + OperationData(unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS) {} + explicit operator bool() const { return Opcode; } /// Get the index of the first operand. unsigned getFirstOperandIndex() const { assert(!!*this && "The opcode is not set."); - switch (Kind) { - case RK_Min: - case RK_UMin: - case RK_Max: - case RK_UMax: - return 1; - case RK_Arithmetic: - case RK_None: - break; - } return 0; } /// Total number of operands in the reduction operation. unsigned getNumberOfOperands() const { - assert(Kind != RK_None && !!*this && LHS && RHS && + assert(!IsReducedValue && !!*this && LHS && RHS && "Expected reduction operation."); - switch (Kind) { - case RK_Arithmetic: - return 2; - case RK_Min: - case RK_UMin: - case RK_Max: - case RK_UMax: - return 3; - case RK_None: - break; - } - llvm_unreachable("Reduction kind is not set"); + return 2; } /// Expected number of uses for reduction operations/reduced values. unsigned getRequiredNumberOfUses() const { - assert(Kind != RK_None && !!*this && LHS && RHS && + assert(!IsReducedValue && !!*this && LHS && RHS && "Expected reduction operation."); - switch (Kind) { - case RK_Arithmetic: - return 1; - case RK_Min: - case RK_UMin: - case RK_Max: - case RK_UMax: - return 2; - case RK_None: - break; - } - llvm_unreachable("Reduction kind is not set"); + return 1; } /// Checks if instruction is associative and can be vectorized. bool isAssociative(Instruction *I) const { - assert(Kind != RK_None && *this && LHS && RHS && + assert(!IsReducedValue && *this && LHS && RHS && "Expected reduction operation."); - switch (Kind) { - case RK_Arithmetic: - return I->isAssociative(); - case RK_Min: - case RK_Max: - return Opcode == Instruction::ICmp || - cast<Instruction>(I->getOperand(0))->hasUnsafeAlgebra(); - case RK_UMin: - case RK_UMax: - assert(Opcode == Instruction::ICmp && - "Only integer compare operation is expected."); - return true; - case RK_None: - break; - } - llvm_unreachable("Reduction kind is not set"); + return I->isAssociative(); } /// Checks if the reduction operation can be vectorized. @@ -4790,17 +4726,18 @@ class HorizontalReduction { /// Checks if two operation data are both a reduction op or both a reduced /// value. bool operator==(const OperationData &OD) { - assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && + assert(((IsReducedValue != OD.IsReducedValue) || + ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && "One of the comparing operations is incorrect."); - return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode); + return this == &OD || + (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode); } bool operator!=(const OperationData &OD) { return !(*this == OD); } void clear() { + IsReducedValue = false; Opcode = 0; LHS = nullptr; RHS = nullptr; - Kind = RK_None; - NoNaN = false; } /// Get the opcode of the reduction operation. @@ -4809,81 +4746,16 @@ class HorizontalReduction { return Opcode; } - /// Get kind of reduction data. - ReductionKind getKind() const { return Kind; } Value *getLHS() const { return LHS; } Value *getRHS() const { return RHS; } - Type *getConditionType() const { - switch (Kind) { - case RK_Arithmetic: - return nullptr; - case RK_Min: - case RK_Max: - case RK_UMin: - case RK_UMax: - return CmpInst::makeCmpResultType(LHS->getType()); - case RK_None: - break; - } - llvm_unreachable("Reduction kind is not set"); - } /// Creates reduction operation with the current opcode. Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const { - assert(isVectorizable() && - "Expected add|fadd or min/max reduction operation."); - Value *Cmp; - switch (Kind) { - case RK_Arithmetic: - return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, - Name); - case RK_Min: - Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS) - : Builder.CreateFCmpOLT(LHS, RHS); - break; - case RK_Max: - Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS) - : Builder.CreateFCmpOGT(LHS, RHS); - break; - case RK_UMin: - assert(Opcode == Instruction::ICmp && "Expected integer types."); - Cmp = Builder.CreateICmpULT(LHS, RHS); - break; - case RK_UMax: - assert(Opcode == Instruction::ICmp && "Expected integer types."); - Cmp = Builder.CreateICmpUGT(LHS, RHS); - break; - case RK_None: - llvm_unreachable("Unknown reduction operation."); - } - return Builder.CreateSelect(Cmp, LHS, RHS, Name); - } - TargetTransformInfo::ReductionFlags getFlags() const { - TargetTransformInfo::ReductionFlags Flags; - Flags.NoNaN = NoNaN; - switch (Kind) { - case RK_Arithmetic: - break; - case RK_Min: - Flags.IsSigned = Opcode == Instruction::ICmp; - Flags.IsMaxOp = false; - break; - case RK_Max: - Flags.IsSigned = Opcode == Instruction::ICmp; - Flags.IsMaxOp = true; - break; - case RK_UMin: - Flags.IsSigned = false; - Flags.IsMaxOp = false; - break; - case RK_UMax: - Flags.IsSigned = false; - Flags.IsMaxOp = true; - break; - case RK_None: - llvm_unreachable("Reduction kind is not set"); - } - return Flags; + assert(!IsReducedValue && + (Opcode == Instruction::FAdd || Opcode == Instruction::Add) && + "Expected add|fadd reduction operation."); + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, + Name); } }; @@ -4925,32 +4797,8 @@ class HorizontalReduction { Value *LHS; Value *RHS; - if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) { - return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS, - RK_Arithmetic); - } - if (auto *Select = dyn_cast<SelectInst>(V)) { - // Look for a min/max pattern. - if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin); - } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_Min); - } else if (m_OrdFMin(m_Value(LHS), m_Value(RHS)).match(Select) || - m_UnordFMin(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData( - Instruction::FCmp, LHS, RHS, RK_Min, - cast<Instruction>(Select->getCondition())->hasNoNaNs()); - } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax); - } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_Max); - } else if (m_OrdFMax(m_Value(LHS), m_Value(RHS)).match(Select) || - m_UnordFMax(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData( - Instruction::FCmp, LHS, RHS, RK_Max, - cast<Instruction>(Select->getCondition())->hasNoNaNs()); - } - } + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) + return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS); return OperationData(V); } @@ -5143,9 +4991,8 @@ public: if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); OperationData VectReductionData(ReductionData.getOpcode(), - VectorizedTree, ReducedSubTree, - ReductionData.getKind()); - VectorizedTree = VectReductionData.createOp(Builder, "op.rdx"); + VectorizedTree, ReducedSubTree); + VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx"); propagateIRFlags(VectorizedTree, ReductionOps); } else VectorizedTree = ReducedSubTree; @@ -5159,8 +5006,7 @@ public: auto *I = cast<Instruction>(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); OperationData VectReductionData(ReductionData.getOpcode(), - VectorizedTree, I, - ReductionData.getKind()); + VectorizedTree, I); VectorizedTree = VectReductionData.createOp(Builder); propagateIRFlags(VectorizedTree, ReductionOps); } @@ -5171,9 +5017,8 @@ public: for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); OperationData VectReductionData(ReductionData.getOpcode(), - VectorizedTree, Pair.first, - ReductionData.getKind()); - VectorizedTree = VectReductionData.createOp(Builder, "op.extra"); + VectorizedTree, Pair.first); + VectorizedTree = VectReductionData.createOp(Builder, "bin.extra"); propagateIRFlags(VectorizedTree, I); } } @@ -5194,58 +5039,19 @@ private: Type *ScalarTy = FirstReducedVal->getType(); Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); - int PairwiseRdxCost; - int SplittingRdxCost; - switch (ReductionData.getKind()) { - case RK_Arithmetic: - PairwiseRdxCost = - TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, - /*IsPairwiseForm=*/true); - SplittingRdxCost = - TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, - /*IsPairwiseForm=*/false); - break; - case RK_Min: - case RK_Max: - case RK_UMin: - case RK_UMax: { - Type *VecCondTy = CmpInst::makeCmpResultType(VecTy); - bool IsUnsigned = ReductionData.getKind() == RK_UMin || - ReductionData.getKind() == RK_UMax; - PairwiseRdxCost = - TTI->getMinMaxReductionCost(VecTy, VecCondTy, - /*IsPairwiseForm=*/true, IsUnsigned); - SplittingRdxCost = - TTI->getMinMaxReductionCost(VecTy, VecCondTy, - /*IsPairwiseForm=*/false, IsUnsigned); - break; - } - case RK_None: - llvm_unreachable("Expected arithmetic or min/max reduction operation"); - } + int PairwiseRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); + int SplittingRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; - int ScalarReduxCost; - switch (ReductionData.getKind()) { - case RK_Arithmetic: - ScalarReduxCost = - TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); - break; - case RK_Min: - case RK_Max: - case RK_UMin: - case RK_UMax: - ScalarReduxCost = - TTI->getCmpSelInstrCost(ReductionData.getOpcode(), ScalarTy) + - TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - CmpInst::makeCmpResultType(ScalarTy)); - break; - case RK_None: - llvm_unreachable("Expected arithmetic or min/max reduction operation"); - } - ScalarReduxCost *= (ReduxWidth - 1); + int ScalarReduxCost = + (ReduxWidth - 1) * + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -5267,7 +5073,7 @@ private: if (!IsPairwiseReduction) return createSimpleTargetReduction( Builder, TTI, ReductionData.getOpcode(), VectorizedValue, - ReductionData.getFlags(), RedOps); + TargetTransformInfo::ReductionFlags(), RedOps); Value *TmpVec = VectorizedValue; for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) { @@ -5282,8 +5088,8 @@ private: TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf, - RightShuf, ReductionData.getKind()); - TmpVec = VectReductionData.createOp(Builder, "op.rdx"); + RightShuf); + TmpVec = VectReductionData.createOp(Builder, "bin.rdx"); propagateIRFlags(TmpVec, RedOps); } @@ -5444,11 +5250,9 @@ static bool tryToVectorizeHorReductionOrInstOperands( auto *Inst = dyn_cast<Instruction>(V); if (!Inst) continue; - auto *BI = dyn_cast<BinaryOperator>(Inst); - auto *SI = dyn_cast<SelectInst>(Inst); - if (BI || SI) { + if (auto *BI = dyn_cast<BinaryOperator>(Inst)) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, Inst)) { + if (HorRdx.matchAssociativeReduction(P, BI)) { if (HorRdx.tryToReduce(R, TTI)) { Res = true; // Set P to nullptr to avoid re-analysis of phi node in @@ -5457,7 +5261,7 @@ static bool tryToVectorizeHorReductionOrInstOperands( continue; } } - if (P && BI) { + if (P) { Inst = dyn_cast<Instruction>(BI->getOperand(0)); if (Inst == P) Inst = dyn_cast<Instruction>(BI->getOperand(1)); |