diff options
author | Alexey Bataev <a.bataev@hotmail.com> | 2017-07-31 14:36:05 +0000 |
---|---|---|
committer | Alexey Bataev <a.bataev@hotmail.com> | 2017-07-31 14:36:05 +0000 |
commit | 0ab22bb991124b9bcd82440a77c562ea19bc98e4 (patch) | |
tree | 40bbf2ea9b6f55679dc111fbf5dce37aeb76d19a /llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | |
parent | 11ea6fdcd7dc3df0401ad28a1d3396fd1713d829 (diff) | |
download | bcm5719-llvm-0ab22bb991124b9bcd82440a77c562ea19bc98e4.tar.gz bcm5719-llvm-0ab22bb991124b9bcd82440a77c562ea19bc98e4.zip |
[SLP] Initial rework for min/max horizontal reduction vectorization, NFC.
Summary: All getReductionCost() functions are renamed to getArithmeticReductionCost() + added basic infrastructure to handle non-binary reduction operations.
Reviewers: spatel, mzolotukhin, Ayal, mkuper, gilr, hfinkel
Subscribers: RKSimon, llvm-commits
Differential Revision: https://reviews.llvm.org/D29402
llvm-svn: 309566
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 205 |
1 files changed, 158 insertions, 47 deletions
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index d8b0f1d99a6..05ee05af31e 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" @@ -48,6 +49,7 @@ #include <memory> using namespace llvm; +using namespace llvm::PatternMatch; using namespace slpvectorizer; #define SV_NAME "slp-vectorizer" @@ -4321,12 +4323,104 @@ class HorizontalReduction { // Use map vector to make stable output. MapVector<Instruction *, Value *> ExtraArgs; - BinaryOperator *ReductionRoot = nullptr; + /// Contains info about operation, like its opcode, left and right operands. + struct OperationData { + /// true if the operation is a reduced value, false if reduction operation. + bool IsReducedValue = false; + /// Opcode of the instruction. + unsigned Opcode = 0; + /// Left operand of the reduction operation. + Value *LHS = nullptr; + /// Right operand of the reduction operation. + Value *RHS = nullptr; + + /// Checks if the reduction operation can be vectorized. + bool isVectorizable() const { + return LHS && RHS && + // We currently only support adds. + (Opcode == Instruction::Add || Opcode == Instruction::FAdd); + } + + public: + explicit OperationData() = default; + /// Construction for reduced values. They are identified by opcode only and + /// don't have associated LHS/RHS values. + explicit OperationData(Value *V) : IsReducedValue(true) { + if (auto *I = dyn_cast<Instruction>(V)) + Opcode = I->getOpcode(); + } + /// Constructor for binary reduction operations with opcode and its left and + /// right operands. + OperationData(unsigned Opcode, Value *LHS, Value *RHS) + : IsReducedValue(false), Opcode(Opcode), LHS(LHS), RHS(RHS) {} + explicit operator bool() const { return Opcode; } + /// Get the index of the first operand. + unsigned getFirstOperandIndex() const { + assert(!!*this && "The opcode is not set."); + return 0; + } + /// Total number of operands in the reduction operation. + unsigned getNumberOfOperands() const { + assert(!IsReducedValue && !!*this && LHS && RHS && + "Expected reduction operation."); + return 2; + } + /// Expected number of uses for reduction operations/reduced values. + unsigned getRequiredNumberOfUses() const { + assert(!IsReducedValue && !!*this && LHS && RHS && + "Expected reduction operation."); + return 1; + } + /// Checks if instruction is associative and can be vectorized. + bool isAssociative(Instruction *I) const { + assert(!IsReducedValue && *this && LHS && RHS && + "Expected reduction operation."); + return I->isAssociative(); + } + /// Checks if the reduction operation can be vectorized. + bool isVectorizable(Instruction *I) const { + return isVectorizable() && isAssociative(I); + } + + /// Checks if two operation data are both a reduction op or both a reduced + /// value. + bool operator==(const OperationData &OD) { + assert((IsReducedValue != OD.IsReducedValue) || + ((!LHS == !OD.LHS) && (!RHS == !OD.RHS)) && + "One of the comparing operations is incorrect."); + return this == &OD || + (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode); + } + bool operator!=(const OperationData &OD) { return !(*this == OD); } + void clear() { + IsReducedValue = false; + Opcode = 0; + LHS = nullptr; + RHS = nullptr; + } + /// Get the opcode of the reduction operation. + unsigned getOpcode() const { + assert(isVectorizable() && "Expected vectorizable operation."); + return Opcode; + } + Value *getLHS() const { return LHS; } + Value *getRHS() const { return RHS; } + /// Creates reduction operation with the current opcode. + Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const { + assert(!IsReducedValue && + (Opcode == Instruction::FAdd || Opcode == Instruction::Add) && + "Expected add|fadd reduction operation."); + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, + Name); + } + }; + + Instruction *ReductionRoot = nullptr; - /// The opcode of the reduction. - Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd; - /// The opcode of the values we perform a reduction on. - unsigned ReducedValueOpcode = 0; + /// The operation data of the reduction operation. + OperationData ReductionData; + /// The operation data of the values we perform a reduction on. + OperationData ReducedValueData; /// Should we model this reduction as a pairwise reduction tree or a tree that /// splits the vector in halves and adds those halves. bool IsPairwiseReduction = false; @@ -4351,55 +4445,65 @@ class HorizontalReduction { } } + static OperationData getOperationData(Value *V) { + if (!V) + return OperationData(); + + Value *LHS; + Value *RHS; + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) + return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS); + return OperationData(V); + } + public: HorizontalReduction() = default; /// \brief Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { + bool matchAssociativeReduction(PHINode *Phi, Instruction *B) { assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); + ReductionData = getOperationData(B); + // We could have a initial reductions that is not an add. // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. if (Phi) { - if (B->getOperand(0) == Phi) { + if (ReductionData.getLHS() == Phi) { Phi = nullptr; - B = dyn_cast<BinaryOperator>(B->getOperand(1)); - } else if (B->getOperand(1) == Phi) { + B = dyn_cast<Instruction>(ReductionData.getRHS()); + ReductionData = getOperationData(B); + } else if (ReductionData.getRHS() == Phi) { Phi = nullptr; - B = dyn_cast<BinaryOperator>(B->getOperand(0)); + B = dyn_cast<Instruction>(ReductionData.getLHS()); + ReductionData = getOperationData(B); } } - if (!B) + if (!ReductionData.isVectorizable(B)) return false; Type *Ty = B->getType(); if (!isValidElementType(Ty)) return false; - ReductionOpcode = B->getOpcode(); - ReducedValueOpcode = 0; + ReducedValueData.clear(); ReductionRoot = B; - // We currently only support adds. - if ((ReductionOpcode != Instruction::Add && - ReductionOpcode != Instruction::FAdd) || - !B->isAssociative()) - return false; - // Post order traverse the reduction tree starting at B. We only handle true - // trees containing only binary operators or selects. + // trees containing only binary operators. SmallVector<std::pair<Instruction *, unsigned>, 32> Stack; - Stack.push_back(std::make_pair(B, 0)); + Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex())); + const unsigned NUses = ReductionData.getRequiredNumberOfUses(); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVist = Stack.back().second++; - bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode; + OperationData OpData = getOperationData(TreeN); + bool IsReducedValue = OpData != ReductionData; // Postorder vist. - if (EdgeToVist == 2 || IsReducedValue) { + if (IsReducedValue || EdgeToVist == OpData.getNumberOfOperands()) { if (IsReducedValue) ReducedVals.push_back(TreeN); else { @@ -4428,12 +4532,13 @@ public: Value *NextV = TreeN->getOperand(EdgeToVist); if (NextV != Phi) { auto *I = dyn_cast<Instruction>(NextV); + OpData = getOperationData(I); // Continue analysis if the next operand is a reduction operation or // (possibly) a reduced value. If the reduced value opcode is not set, // the first met operation != reduction operation is considered as the // reduced value class. - if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode || - I->getOpcode() == ReductionOpcode)) { + if (I && (!ReducedValueData || OpData == ReducedValueData || + OpData == ReductionData)) { // Only handle trees in the current basic block. if (I->getParent() != B->getParent()) { // I is an extra argument for TreeN (its parent operation). @@ -4441,32 +4546,32 @@ public: continue; } - // Each tree node needs to have one user except for the ultimate - // reduction. - if (!I->hasOneUse() && I != B) { + // Each tree node needs to have minimal number of users except for the + // ultimate reduction. + if (!I->hasNUses(NUses) && I != B) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - if (I->getOpcode() == ReductionOpcode) { + if (OpData == ReductionData) { // We need to be able to reassociate the reduction operations. - if (!I->isAssociative()) { + if (!OpData.isAssociative(I)) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - } else if (ReducedValueOpcode && - ReducedValueOpcode != I->getOpcode()) { + } else if (ReducedValueData && + ReducedValueData != OpData) { // Make sure that the opcodes of the operations that we are going to // reduce match. // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; - } else if (!ReducedValueOpcode) - ReducedValueOpcode = I->getOpcode(); + } else if (!ReducedValueData) + ReducedValueData = OpData; - Stack.push_back(std::make_pair(I, 0)); + Stack.push_back(std::make_pair(I, OpData.getFirstOperandIndex())); continue; } } @@ -4539,8 +4644,9 @@ public: emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps, TTI); if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - ReducedSubTree, "bin.rdx"); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, ReducedSubTree); + VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx"); propagateIRFlags(VectorizedTree, ReductionOps); } else VectorizedTree = ReducedSubTree; @@ -4553,8 +4659,9 @@ public: for (; i < NumReducedVals; ++i) { auto *I = cast<Instruction>(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = - Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, I); + VectorizedTree = VectReductionData.createOp(Builder); propagateIRFlags(VectorizedTree, ReductionOps); } for (auto &Pair : ExternallyUsedValues) { @@ -4563,8 +4670,9 @@ public: // Add each externally used value to the final reduction. for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - Pair.first, "bin.extra"); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, Pair.first); + VectorizedTree = VectReductionData.createOp(Builder, "bin.extra"); propagateIRFlags(VectorizedTree, I); } } @@ -4586,16 +4694,18 @@ private: Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); int PairwiseRdxCost = - TTI->getArithmeticReductionCost(ReductionOpcode, VecTy, true); + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); int SplittingRdxCost = - TTI->getArithmeticReductionCost(ReductionOpcode, VecTy, false); + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; int ScalarReduxCost = (ReduxWidth - 1) * - TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -4616,7 +4726,7 @@ private: if (!IsPairwiseReduction) return createSimpleTargetReduction( - Builder, TTI, ReductionOpcode, VectorizedValue, + Builder, TTI, ReductionData.getOpcode(), VectorizedValue, TargetTransformInfo::ReductionFlags(), RedOps); Value *TmpVec = VectorizedValue; @@ -4631,8 +4741,9 @@ private: Value *RightShuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); - TmpVec = - Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, "bin.rdx"); + OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf, + RightShuf); + TmpVec = VectReductionData.createOp(Builder, "bin.rdx"); propagateIRFlags(TmpVec, RedOps); } |