summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp')
-rw-r--r--llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp294
1 files changed, 49 insertions, 245 deletions
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 03b6a8df33b..0e47dc1a407 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4653,17 +4653,11 @@ class HorizontalReduction {
// Use map vector to make stable output.
MapVector<Instruction *, Value *> ExtraArgs;
- /// Kind of the reduction data.
- enum ReductionKind {
- RK_None, /// Not a reduction.
- RK_Arithmetic, /// Binary reduction data.
- RK_Min, /// Minimum reduction data.
- RK_UMin, /// Unsigned minimum reduction data.
- RK_Max, /// Maximum reduction data.
- RK_UMax, /// Unsigned maximum reduction data.
- };
/// Contains info about operation, like its opcode, left and right operands.
- class OperationData {
+ struct OperationData {
+ /// true if the operation is a reduced value, false if reduction operation.
+ bool IsReducedValue = false;
+
/// Opcode of the instruction.
unsigned Opcode = 0;
@@ -4672,21 +4666,12 @@ class HorizontalReduction {
/// Right operand of the reduction operation.
Value *RHS = nullptr;
- /// Kind of the reduction operation.
- ReductionKind Kind = RK_None;
- /// True if float point min/max reduction has no NaNs.
- bool NoNaN = false;
/// Checks if the reduction operation can be vectorized.
bool isVectorizable() const {
return LHS && RHS &&
- // We currently only support adds && min/max reductions.
- ((Kind == RK_Arithmetic &&
- (Opcode == Instruction::Add || Opcode == Instruction::FAdd)) ||
- ((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
- (Kind == RK_Min || Kind == RK_Max)) ||
- (Opcode == Instruction::ICmp &&
- (Kind == RK_UMin || Kind == RK_UMax)));
+ // We currently only support adds.
+ (Opcode == Instruction::Add || Opcode == Instruction::FAdd);
}
public:
@@ -4694,92 +4679,43 @@ class HorizontalReduction {
/// Construction for reduced values. They are identified by opcode only and
/// don't have associated LHS/RHS values.
- explicit OperationData(Value *V) : Kind(RK_None) {
+ explicit OperationData(Value *V) : IsReducedValue(true) {
if (auto *I = dyn_cast<Instruction>(V))
Opcode = I->getOpcode();
}
- /// Constructor for reduction operations with opcode and its left and
+ /// Constructor for binary reduction operations with opcode and its left and
/// right operands.
- OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind,
- bool NoNaN = false)
- : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), NoNaN(NoNaN) {
- assert(Kind != RK_None && "One of the reduction operations is expected.");
- }
+ OperationData(unsigned Opcode, Value *LHS, Value *RHS)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS) {}
+
explicit operator bool() const { return Opcode; }
/// Get the index of the first operand.
unsigned getFirstOperandIndex() const {
assert(!!*this && "The opcode is not set.");
- switch (Kind) {
- case RK_Min:
- case RK_UMin:
- case RK_Max:
- case RK_UMax:
- return 1;
- case RK_Arithmetic:
- case RK_None:
- break;
- }
return 0;
}
/// Total number of operands in the reduction operation.
unsigned getNumberOfOperands() const {
- assert(Kind != RK_None && !!*this && LHS && RHS &&
+ assert(!IsReducedValue && !!*this && LHS && RHS &&
"Expected reduction operation.");
- switch (Kind) {
- case RK_Arithmetic:
- return 2;
- case RK_Min:
- case RK_UMin:
- case RK_Max:
- case RK_UMax:
- return 3;
- case RK_None:
- break;
- }
- llvm_unreachable("Reduction kind is not set");
+ return 2;
}
/// Expected number of uses for reduction operations/reduced values.
unsigned getRequiredNumberOfUses() const {
- assert(Kind != RK_None && !!*this && LHS && RHS &&
+ assert(!IsReducedValue && !!*this && LHS && RHS &&
"Expected reduction operation.");
- switch (Kind) {
- case RK_Arithmetic:
- return 1;
- case RK_Min:
- case RK_UMin:
- case RK_Max:
- case RK_UMax:
- return 2;
- case RK_None:
- break;
- }
- llvm_unreachable("Reduction kind is not set");
+ return 1;
}
/// Checks if instruction is associative and can be vectorized.
bool isAssociative(Instruction *I) const {
- assert(Kind != RK_None && *this && LHS && RHS &&
+ assert(!IsReducedValue && *this && LHS && RHS &&
"Expected reduction operation.");
- switch (Kind) {
- case RK_Arithmetic:
- return I->isAssociative();
- case RK_Min:
- case RK_Max:
- return Opcode == Instruction::ICmp ||
- cast<Instruction>(I->getOperand(0))->hasUnsafeAlgebra();
- case RK_UMin:
- case RK_UMax:
- assert(Opcode == Instruction::ICmp &&
- "Only integer compare operation is expected.");
- return true;
- case RK_None:
- break;
- }
- llvm_unreachable("Reduction kind is not set");
+ return I->isAssociative();
}
/// Checks if the reduction operation can be vectorized.
@@ -4790,17 +4726,18 @@ class HorizontalReduction {
/// Checks if two operation data are both a reduction op or both a reduced
/// value.
bool operator==(const OperationData &OD) {
- assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
+ assert(((IsReducedValue != OD.IsReducedValue) ||
+ ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
"One of the comparing operations is incorrect.");
- return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
+ return this == &OD ||
+ (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode);
}
bool operator!=(const OperationData &OD) { return !(*this == OD); }
void clear() {
+ IsReducedValue = false;
Opcode = 0;
LHS = nullptr;
RHS = nullptr;
- Kind = RK_None;
- NoNaN = false;
}
/// Get the opcode of the reduction operation.
@@ -4809,81 +4746,16 @@ class HorizontalReduction {
return Opcode;
}
- /// Get kind of reduction data.
- ReductionKind getKind() const { return Kind; }
Value *getLHS() const { return LHS; }
Value *getRHS() const { return RHS; }
- Type *getConditionType() const {
- switch (Kind) {
- case RK_Arithmetic:
- return nullptr;
- case RK_Min:
- case RK_Max:
- case RK_UMin:
- case RK_UMax:
- return CmpInst::makeCmpResultType(LHS->getType());
- case RK_None:
- break;
- }
- llvm_unreachable("Reduction kind is not set");
- }
/// Creates reduction operation with the current opcode.
Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const {
- assert(isVectorizable() &&
- "Expected add|fadd or min/max reduction operation.");
- Value *Cmp;
- switch (Kind) {
- case RK_Arithmetic:
- return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
- Name);
- case RK_Min:
- Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS)
- : Builder.CreateFCmpOLT(LHS, RHS);
- break;
- case RK_Max:
- Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS)
- : Builder.CreateFCmpOGT(LHS, RHS);
- break;
- case RK_UMin:
- assert(Opcode == Instruction::ICmp && "Expected integer types.");
- Cmp = Builder.CreateICmpULT(LHS, RHS);
- break;
- case RK_UMax:
- assert(Opcode == Instruction::ICmp && "Expected integer types.");
- Cmp = Builder.CreateICmpUGT(LHS, RHS);
- break;
- case RK_None:
- llvm_unreachable("Unknown reduction operation.");
- }
- return Builder.CreateSelect(Cmp, LHS, RHS, Name);
- }
- TargetTransformInfo::ReductionFlags getFlags() const {
- TargetTransformInfo::ReductionFlags Flags;
- Flags.NoNaN = NoNaN;
- switch (Kind) {
- case RK_Arithmetic:
- break;
- case RK_Min:
- Flags.IsSigned = Opcode == Instruction::ICmp;
- Flags.IsMaxOp = false;
- break;
- case RK_Max:
- Flags.IsSigned = Opcode == Instruction::ICmp;
- Flags.IsMaxOp = true;
- break;
- case RK_UMin:
- Flags.IsSigned = false;
- Flags.IsMaxOp = false;
- break;
- case RK_UMax:
- Flags.IsSigned = false;
- Flags.IsMaxOp = true;
- break;
- case RK_None:
- llvm_unreachable("Reduction kind is not set");
- }
- return Flags;
+ assert(!IsReducedValue &&
+ (Opcode == Instruction::FAdd || Opcode == Instruction::Add) &&
+ "Expected add|fadd reduction operation.");
+ return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS,
+ Name);
}
};
@@ -4925,32 +4797,8 @@ class HorizontalReduction {
Value *LHS;
Value *RHS;
- if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) {
- return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS,
- RK_Arithmetic);
- }
- if (auto *Select = dyn_cast<SelectInst>(V)) {
- // Look for a min/max pattern.
- if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
- return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
- } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
- return OperationData(Instruction::ICmp, LHS, RHS, RK_Min);
- } else if (m_OrdFMin(m_Value(LHS), m_Value(RHS)).match(Select) ||
- m_UnordFMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
- return OperationData(
- Instruction::FCmp, LHS, RHS, RK_Min,
- cast<Instruction>(Select->getCondition())->hasNoNaNs());
- } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
- return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
- } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
- return OperationData(Instruction::ICmp, LHS, RHS, RK_Max);
- } else if (m_OrdFMax(m_Value(LHS), m_Value(RHS)).match(Select) ||
- m_UnordFMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
- return OperationData(
- Instruction::FCmp, LHS, RHS, RK_Max,
- cast<Instruction>(Select->getCondition())->hasNoNaNs());
- }
- }
+ if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V))
+ return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS);
return OperationData(V);
}
@@ -5143,9 +4991,8 @@ public:
if (VectorizedTree) {
Builder.SetCurrentDebugLocation(Loc);
OperationData VectReductionData(ReductionData.getOpcode(),
- VectorizedTree, ReducedSubTree,
- ReductionData.getKind());
- VectorizedTree = VectReductionData.createOp(Builder, "op.rdx");
+ VectorizedTree, ReducedSubTree);
+ VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx");
propagateIRFlags(VectorizedTree, ReductionOps);
} else
VectorizedTree = ReducedSubTree;
@@ -5159,8 +5006,7 @@ public:
auto *I = cast<Instruction>(ReducedVals[i]);
Builder.SetCurrentDebugLocation(I->getDebugLoc());
OperationData VectReductionData(ReductionData.getOpcode(),
- VectorizedTree, I,
- ReductionData.getKind());
+ VectorizedTree, I);
VectorizedTree = VectReductionData.createOp(Builder);
propagateIRFlags(VectorizedTree, ReductionOps);
}
@@ -5171,9 +5017,8 @@ public:
for (auto *I : Pair.second) {
Builder.SetCurrentDebugLocation(I->getDebugLoc());
OperationData VectReductionData(ReductionData.getOpcode(),
- VectorizedTree, Pair.first,
- ReductionData.getKind());
- VectorizedTree = VectReductionData.createOp(Builder, "op.extra");
+ VectorizedTree, Pair.first);
+ VectorizedTree = VectReductionData.createOp(Builder, "bin.extra");
propagateIRFlags(VectorizedTree, I);
}
}
@@ -5194,58 +5039,19 @@ private:
Type *ScalarTy = FirstReducedVal->getType();
Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
- int PairwiseRdxCost;
- int SplittingRdxCost;
- switch (ReductionData.getKind()) {
- case RK_Arithmetic:
- PairwiseRdxCost =
- TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
- /*IsPairwiseForm=*/true);
- SplittingRdxCost =
- TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
- /*IsPairwiseForm=*/false);
- break;
- case RK_Min:
- case RK_Max:
- case RK_UMin:
- case RK_UMax: {
- Type *VecCondTy = CmpInst::makeCmpResultType(VecTy);
- bool IsUnsigned = ReductionData.getKind() == RK_UMin ||
- ReductionData.getKind() == RK_UMax;
- PairwiseRdxCost =
- TTI->getMinMaxReductionCost(VecTy, VecCondTy,
- /*IsPairwiseForm=*/true, IsUnsigned);
- SplittingRdxCost =
- TTI->getMinMaxReductionCost(VecTy, VecCondTy,
- /*IsPairwiseForm=*/false, IsUnsigned);
- break;
- }
- case RK_None:
- llvm_unreachable("Expected arithmetic or min/max reduction operation");
- }
+ int PairwiseRdxCost =
+ TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
+ /*IsPairwiseForm=*/true);
+ int SplittingRdxCost =
+ TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy,
+ /*IsPairwiseForm=*/false);
IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost;
int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost;
- int ScalarReduxCost;
- switch (ReductionData.getKind()) {
- case RK_Arithmetic:
- ScalarReduxCost =
- TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
- break;
- case RK_Min:
- case RK_Max:
- case RK_UMin:
- case RK_UMax:
- ScalarReduxCost =
- TTI->getCmpSelInstrCost(ReductionData.getOpcode(), ScalarTy) +
- TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
- CmpInst::makeCmpResultType(ScalarTy));
- break;
- case RK_None:
- llvm_unreachable("Expected arithmetic or min/max reduction operation");
- }
- ScalarReduxCost *= (ReduxWidth - 1);
+ int ScalarReduxCost =
+ (ReduxWidth - 1) *
+ TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost
<< " for reduction that starts with " << *FirstReducedVal
@@ -5267,7 +5073,7 @@ private:
if (!IsPairwiseReduction)
return createSimpleTargetReduction(
Builder, TTI, ReductionData.getOpcode(), VectorizedValue,
- ReductionData.getFlags(), RedOps);
+ TargetTransformInfo::ReductionFlags(), RedOps);
Value *TmpVec = VectorizedValue;
for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
@@ -5282,8 +5088,8 @@ private:
TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
"rdx.shuf.r");
OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
- RightShuf, ReductionData.getKind());
- TmpVec = VectReductionData.createOp(Builder, "op.rdx");
+ RightShuf);
+ TmpVec = VectReductionData.createOp(Builder, "bin.rdx");
propagateIRFlags(TmpVec, RedOps);
}
@@ -5444,11 +5250,9 @@ static bool tryToVectorizeHorReductionOrInstOperands(
auto *Inst = dyn_cast<Instruction>(V);
if (!Inst)
continue;
- auto *BI = dyn_cast<BinaryOperator>(Inst);
- auto *SI = dyn_cast<SelectInst>(Inst);
- if (BI || SI) {
+ if (auto *BI = dyn_cast<BinaryOperator>(Inst)) {
HorizontalReduction HorRdx;
- if (HorRdx.matchAssociativeReduction(P, Inst)) {
+ if (HorRdx.matchAssociativeReduction(P, BI)) {
if (HorRdx.tryToReduce(R, TTI)) {
Res = true;
// Set P to nullptr to avoid re-analysis of phi node in
@@ -5457,7 +5261,7 @@ static bool tryToVectorizeHorReductionOrInstOperands(
continue;
}
}
- if (P && BI) {
+ if (P) {
Inst = dyn_cast<Instruction>(BI->getOperand(0));
if (Inst == P)
Inst = dyn_cast<Instruction>(BI->getOperand(1));
OpenPOWER on IntegriCloud