 mlir/include/mlir/Analysis/NestedMatcher.h               | 165
 mlir/lib/Analysis/LoopAnalysis.cpp                       |  14
 mlir/lib/Analysis/NestedMatcher.cpp                      | 158
 mlir/lib/Transforms/MaterializeVectors.cpp               |  11
 mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp |  71
 mlir/lib/Transforms/Vectorize.cpp                        | 192
 6 files changed, 291 insertions(+), 320 deletions(-)
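
Before the patch itself, a minimal sketch of how the reworked API is meant to be
used, based on the matcher::Op helper and the NestedMatch accessors visible in
the hunks below; collectLoadsAndStores is a hypothetical example function, not
part of the patch:

    #include "mlir/Analysis/NestedMatcher.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    // Match every load/store under `f` with the new out-parameter API.
    void collectLoadsAndStores(Function *f) {
      // RAII context; owns the BumpPtrAllocator backing NestedPattern and
      // NestedMatch, and frees everything when it goes out of scope.
      NestedPatternContext mlContext;
      auto pattern = matcher::Op(matcher::isLoadOrStore);
      SmallVector<NestedMatch, 8> matches;
      pattern.match(f, &matches);
      for (auto m : matches) {
        Instruction *inst = m.getMatchedInstruction();
        (void)inst; // process each matched load or store here
      }
    }
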
diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h
index 161bb217a10..2a1c469348d 100644
--- a/mlir/include/mlir/Analysis/NestedMatcher.h
+++ b/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -20,20 +20,19 @@
 #include "mlir/IR/InstVisitor.h"
 #include "llvm/Support/Allocator.h"
-#include <utility>
 
 namespace mlir {
 
-struct NestedPatternStorage;
-struct NestedMatchStorage;
+struct NestedPattern;
 class Instruction;
 
-/// An NestedPattern captures nested patterns. It is used in conjunction with
-/// a scoped NestedPatternContext which is an llvm::BumPtrAllocator that
-/// handles memory allocations efficiently and avoids ownership issues.
+/// A NestedPattern captures nested patterns in the IR.
+/// It is used in conjunction with a scoped NestedPatternContext which is an
+/// llvm::BumpPtrAllocator that handles memory allocations efficiently and
+/// avoids ownership issues.
 ///
-/// In order to use NestedPatterns, first create a scoped context. When the
-/// context goes out of scope, everything is freed.
+/// In order to use NestedPatterns, first create a scoped context.
+/// When the context goes out of scope, everything is freed.
 /// This design simplifies the API by avoiding references to the context and
 /// makes it clear that references to matchers must not escape.
 ///
@@ -45,109 +44,145 @@ class Instruction;
 ///   // do work on matches
 /// }  // everything is freed
 ///
-
-/// Recursive abstraction for matching results.
-/// Provides iteration over the Instruction* captured by a Matcher.
 ///
-/// Implemented as a POD value-type with underlying storage pointer.
-/// The underlying storage lives in a scoped bumper allocator whose lifetime
-/// is managed by an RAII NestedPatternContext.
-/// This is used by value everywhere.
+/// Nested abstraction for matching results.
+/// Provides access to the nested Instruction* captured by a Matcher.
+///
+/// A NestedMatch contains an Instruction* and its children NestedMatches and
+/// is thus cheap to copy. NestedMatch is stored in a scoped bump allocator
+/// whose lifetime is managed by an RAII NestedPatternContext.
 struct NestedMatch {
-  using EntryType = std::pair<Instruction *, NestedMatch>;
-  using iterator = EntryType *;
-
-  static NestedMatch build(ArrayRef<NestedMatch::EntryType> elements = {});
+  static NestedMatch build(Instruction *instruction,
+                           ArrayRef<NestedMatch> nestedMatches);
   NestedMatch(const NestedMatch &) = default;
   NestedMatch &operator=(const NestedMatch &) = default;
 
-  explicit operator bool() { return !empty(); }
+  explicit operator bool() { return matchedInstruction != nullptr; }
 
-  iterator begin();
-  iterator end();
-  EntryType &front();
-  EntryType &back();
-  unsigned size() { return end() - begin(); }
-  unsigned empty() { return size() == 0; }
+  Instruction *getMatchedInstruction() { return matchedInstruction; }
+  ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; }
 
 private:
   friend class NestedPattern;
   friend class NestedPatternContext;
-  friend class NestedMatchStorage;
 
   /// Underlying global bump allocator managed by a NestedPatternContext.
   static llvm::BumpPtrAllocator *&allocator();
 
-  NestedMatch(NestedMatchStorage *storage) : storage(storage){};
-
-  /// Copy the specified array of elements into memory managed by our bump
-  /// pointer allocator. The elements are all PODs by constructions.
-  static NestedMatch copyInto(ArrayRef<NestedMatch::EntryType> elements);
+  NestedMatch() = default;
 
-  /// POD payload.
-  NestedMatchStorage *storage;
+  /// Payload, holds a NestedMatch and all its children along this branch.
+  Instruction *matchedInstruction;
+  ArrayRef<NestedMatch> matchedChildren;
 };
 
-/// A NestedPattern is a special type of InstWalker that:
+/// A NestedPattern is a nested InstWalker that:
 ///   1. recursively matches a substructure in the tree;
 ///   2. uses a filter function to refine matches with extra semantic
 ///      constraints (passed via a lambda of type FilterFunctionType);
-///   3. TODO(ntv) Optionally applies actions (lambda).
+///   3. TODO(ntv) optionally applies actions (lambda).
 ///
-/// Implemented as a POD value-type with underlying storage pointer.
-/// The underlying storage lives in a scoped bumper allocator whose lifetime
-/// is managed by an RAII NestedPatternContext.
-/// This should be used by value everywhere.
+/// Nested patterns are meant to capture imperfectly nested loops while matching
+/// properties over the whole loop nest. For instance, in vectorization we are
+/// interested in capturing all the imperfectly nested loops of a certain type,
+/// such that all the loads and stores have certain access patterns along the
+/// loops' induction variables. Such NestedMatches are first captured using the
+/// `match` function and are later processed to analyze properties and apply
+/// transformations in a non-greedy way.
+///
+/// The NestedMatches captured in the IR can grow large, especially after
+/// aggressive unrolling. As experience has shown, it is generally better to use
+/// a plain InstWalker to match flat patterns but the current implementation is
+/// competitive nonetheless.
 using FilterFunctionType = std::function<bool(const Instruction &)>;
 static bool defaultFilterFunction(const Instruction &) { return true; };
 
-struct NestedPattern : public InstWalker<NestedPattern> {
+struct NestedPattern {
   NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested,
                 FilterFunctionType filter = defaultFilterFunction);
   NestedPattern(const NestedPattern &) = default;
   NestedPattern &operator=(const NestedPattern &) = default;
 
-  /// Returns all the matches in `function`.
-  NestedMatch match(Function *function);
+  /// Collects all the top-level matches in `function` into `matches`.
+  void match(Function *function, SmallVectorImpl<NestedMatch> *matches) {
+    State state(*this, matches);
+    state.walkPostOrder(function);
+  }
 
-  /// Returns all the matches nested under `instruction`.
-  NestedMatch match(Instruction *instruction);
+  /// Collects all the top-level matches in `inst` into `matches`.
+  void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) {
+    State state(*this, matches);
+    state.walkPostOrder(inst);
+  }
 
-  unsigned getDepth();
+  /// Returns the depth of the pattern.
+  unsigned getDepth() const;
 
 private:
   friend class NestedPatternContext;
-  friend InstWalker<NestedPattern>;
+  friend class NestedMatch;
+  friend class InstWalker<NestedPattern>;
+  friend struct State;
+
+  /// Helper state that temporarily holds matches for the next level of nesting.
+  struct State : public InstWalker<State> {
+    State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
+        : pattern(pattern), matches(matches) {}
+    void visitForInst(ForInst *forInst) { pattern.matchOne(forInst, matches); }
+    void visitOperationInst(OperationInst *opInst) {
+      pattern.matchOne(opInst, matches);
+    }
+
+  private:
+    NestedPattern &pattern;
+    SmallVectorImpl<NestedMatch> *matches;
+  };
 
   /// Underlying global bump allocator managed by a NestedPatternContext.
   static llvm::BumpPtrAllocator *&allocator();
 
-  Instruction::Kind getKind();
-  ArrayRef<NestedPattern> getNestedPatterns();
-  FilterFunctionType getFilterFunction();
-
-  void matchOne(Instruction *elem);
-
-  void visitForInst(ForInst *forInst) { matchOne(forInst); }
-  void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
-
-  /// POD paylod.
-  /// Storage for the PatternMatcher.
-  NestedPatternStorage *storage;
-
-  // By-value POD wrapper to underlying storage pointer.
-  NestedMatch matches;
+  /// Matches this pattern against a single `inst` and fills matches with the
+  /// result.
+  void matchOne(Instruction *inst, SmallVectorImpl<NestedMatch> *matches);
+
+  /// Instruction kind matched by this pattern.
+  Instruction::Kind kind;
+
+  /// Nested patterns to be matched.
+  ArrayRef<NestedPattern> nestedPatterns;
+
+  /// Extra filter function to apply to prune patterns as the IR is walked.
+  FilterFunctionType filter;
+
+  /// skip is an implementation detail needed so that we can implement match
+  /// without switching on the type of the Instruction. The idea is that a
+  /// NestedPattern first checks if it matches locally and then recursively
+  /// applies its nested matchers to its elem->nested. Since we want to rely on
+  /// the InstWalker impl rather than duplicate its logic, we allow an
+  /// off-by-one traversal to account for the fact that we write:
+  ///
+  ///   void match(Instruction *elem) {
+  ///     for (auto &c : getNestedPatterns()) {
+  ///       NestedPattern childPattern(...);
+  ///                     ^~~~ Needs off-by-one skip.
+  ///
+  Instruction *skip;
 };
 
 /// RAII structure to transparently manage the bump allocator for
-/// NestedPattern and NestedMatch classes.
+/// NestedPattern and NestedMatch classes. This avoids passing a context to
+/// all the API functions.
 struct NestedPatternContext {
   NestedPatternContext() {
-    NestedPattern::allocator() = &allocator;
+    assert(NestedMatch::allocator() == nullptr &&
+           "Only a single NestedPatternContext is supported");
+    assert(NestedPattern::allocator() == nullptr &&
+           "Only a single NestedPatternContext is supported");
     NestedMatch::allocator() = &allocator;
+    NestedPattern::allocator() = &allocator;
   }
   ~NestedPatternContext() {
-    NestedPattern::allocator() = nullptr;
     NestedMatch::allocator() = nullptr;
+    NestedPattern::allocator() = nullptr;
   }
   llvm::BumpPtrAllocator allocator;
 };
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 07c903a6613..7d88a3d9b9f 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -242,7 +242,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
   // No vectorization across conditionals for now.
   auto conditionals = matcher::If();
   auto *forInst = const_cast<ForInst *>(&loop);
-  auto conditionalsMatched = conditionals.match(forInst);
+  SmallVector<NestedMatch, 8> conditionalsMatched;
+  conditionals.match(forInst, &conditionalsMatched);
   if (!conditionalsMatched.empty()) {
     return false;
   }
@@ -252,21 +253,24 @@
     auto &opInst = cast<OperationInst>(inst);
     return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
   });
-  auto regionsMatched = regions.match(forInst);
+  SmallVector<NestedMatch, 8> regionsMatched;
+  regions.match(forInst, &regionsMatched);
   if (!regionsMatched.empty()) {
     return false;
   }
 
   auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
-  auto vectorTransfersMatched = vectorTransfers.match(forInst);
+  SmallVector<NestedMatch, 8> vectorTransfersMatched;
+  vectorTransfers.match(forInst, &vectorTransfersMatched);
   if (!vectorTransfersMatched.empty()) {
     return false;
  }
 
   auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
-  auto loadAndStoresMatched = loadAndStores.match(forInst);
+  SmallVector<NestedMatch, 8> loadAndStoresMatched;
+  loadAndStores.match(forInst, &loadAndStoresMatched);
   for (auto ls : loadAndStoresMatched) {
-    auto *op = cast<OperationInst>(ls.first);
+    auto *op = cast<OperationInst>(ls.getMatchedInstruction());
     auto load = op->dyn_cast<LoadOp>();
     auto store = op->dyn_cast<StoreOp>();
     // Only scalar types are considered vectorizable, all load/store must be
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 491a9bef1b9..43e3b332a58 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -20,93 +20,49 @@
 #include "mlir/StandardOps/StandardOps.h"
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/raw_ostream.h"
 
-namespace mlir {
-
-/// Underlying storage for NestedMatch.
-struct NestedMatchStorage {
-  MutableArrayRef<NestedMatch::EntryType> matches;
-};
-
-/// Underlying storage for NestedPattern.
-struct NestedPatternStorage {
-  NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c,
-                       FilterFunctionType filter, Instruction *skip)
-      : kind(k), nestedPatterns(c), filter(filter), skip(skip) {}
-
-  Instruction::Kind kind;
-  ArrayRef<NestedPattern> nestedPatterns;
-  FilterFunctionType filter;
-  /// skip is needed so that we can implement match without switching on the
-  /// type of the Instruction.
-  /// The idea is that a NestedPattern first checks if it matches locally
-  /// and then recursively applies its nested matchers to its elem->nested.
-  /// Since we want to rely on the InstWalker impl rather than duplicate its
-  /// the logic, we allow an off-by-one traversal to account for the fact that
-  /// we write:
-  ///
-  ///   void match(Instruction *elem) {
-  ///     for (auto &c : getNestedPatterns()) {
-  ///       NestedPattern childPattern(...);
-  ///                     ^~~~ Needs off-by-one skip.
-  ///
-  Instruction *skip;
-};
-
-} // end namespace mlir
-
 using namespace mlir;
 
 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
-  static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
   return allocator;
 }
 
-NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) {
-  auto *matches =
-      allocator()->Allocate<NestedMatch::EntryType>(elements.size());
-  std::uninitialized_copy(elements.begin(), elements.end(), matches);
-  auto *storage = allocator()->Allocate<NestedMatchStorage>();
-  new (storage) NestedMatchStorage();
-  storage->matches =
-      MutableArrayRef<NestedMatch::EntryType>(matches, elements.size());
+NestedMatch NestedMatch::build(Instruction *instruction,
+                               ArrayRef<NestedMatch> nestedMatches) {
   auto *result = allocator()->Allocate<NestedMatch>();
-  new (result) NestedMatch(storage);
+  auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
+  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
+  new (result) NestedMatch();
+  result->matchedInstruction = instruction;
+  result->matchedChildren =
+      ArrayRef<NestedMatch>(children, nestedMatches.size());
   return *result;
 }
 
-NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); }
-NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); }
-NestedMatch::EntryType &NestedMatch::front() {
-  return *storage->matches.begin();
-}
-NestedMatch::EntryType &NestedMatch::back() {
-  return *(storage->matches.begin() + size() - 1);
-}
-
-/// Calls walk on `function`.
-NestedMatch NestedPattern::match(Function *function) {
-  assert(!matches && "NestedPattern already matched!");
-  this->walkPostOrder(function);
-  return matches;
+llvm::BumpPtrAllocator *&NestedPattern::allocator() {
+  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+  return allocator;
 }
 
-/// Calls walk on `instruction`.
-NestedMatch NestedPattern::match(Instruction *instruction) {
-  assert(!matches && "NestedPattern already matched!");
-  this->walkPostOrder(instruction);
-  return matches;
+NestedPattern::NestedPattern(Instruction::Kind k,
+                             ArrayRef<NestedPattern> nested,
+                             FilterFunctionType filter)
+    : kind(k), nestedPatterns(ArrayRef<NestedPattern>(nested)), filter(filter) {
+  auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
+  std::uninitialized_copy(nested.begin(), nested.end(), newNested);
+  nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
 }
 
-unsigned NestedPattern::getDepth() {
-  auto nested = getNestedPatterns();
-  if (nested.empty()) {
+unsigned NestedPattern::getDepth() const {
+  if (nestedPatterns.empty()) {
     return 1;
   }
   unsigned depth = 0;
-  for (auto c : nested) {
+  for (auto &c : nestedPatterns) {
     depth = std::max(depth, c.getDepth());
   }
   return depth + 1;
@@ -122,69 +78,39 @@ unsigned NestedPattern::getDepth() {
 ///    appended to the list of matches;
 /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
 ///    want to traverse in post-order DFS to avoid invalidating iterators.
-void NestedPattern::matchOne(Instruction *elem) {
-  if (storage->skip == elem) {
+void NestedPattern::matchOne(Instruction *inst,
+                             SmallVectorImpl<NestedMatch> *matches) {
+  if (skip == inst) {
     return;
   }
   // Structural filter
-  if (elem->getKind() != getKind()) {
+  if (inst->getKind() != kind) {
     return;
   }
   // Local custom filter function
-  if (!getFilterFunction()(*elem)) {
+  if (!filter(*inst)) {
     return;
   }
 
-  SmallVector<NestedMatch::EntryType, 8> nestedEntries;
-  for (auto c : getNestedPatterns()) {
-    /// We create a new nestedPattern here because a matcher holds its
-    /// results. So we concretely need multiple copies of a given matcher, one
-    /// for each matching result.
-    NestedPattern nestedPattern(c);
+  if (nestedPatterns.empty()) {
+    SmallVector<NestedMatch, 8> nestedMatches;
+    matches->push_back(NestedMatch::build(inst, nestedMatches));
+    return;
+  }
+  // Take a copy of each nested pattern so we can match it.
+  for (auto nestedPattern : nestedPatterns) {
+    SmallVector<NestedMatch, 8> nestedMatches;
     // Skip elem in the walk immediately following. Without this we would
     // essentially need to reimplement walkPostOrder here.
-    nestedPattern.storage->skip = elem;
-    nestedPattern.walkPostOrder(elem);
-    if (!nestedPattern.matches) {
+    nestedPattern.skip = inst;
+    nestedPattern.match(inst, &nestedMatches);
+    // If we could not match even one of the specified nested patterns, exit
+    // early as this whole branch is not a match.
+    if (nestedMatches.empty()) {
      return;
    }
-    for (auto m : nestedPattern.matches) {
-      nestedEntries.push_back(m);
-    }
+    matches->push_back(NestedMatch::build(inst, nestedMatches));
   }
-
-  SmallVector<NestedMatch::EntryType, 8> newEntries(
-      matches.storage->matches.begin(), matches.storage->matches.end());
-  newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries)));
-  matches = NestedMatch::build(newEntries);
-}
-
-llvm::BumpPtrAllocator *&NestedPattern::allocator() {
-  static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
-  return allocator;
-}
-
-NestedPattern::NestedPattern(Instruction::Kind k,
-                             ArrayRef<NestedPattern> nested,
-                             FilterFunctionType filter)
-    : storage(allocator()->Allocate<NestedPatternStorage>()),
-      matches(NestedMatch::build({})) {
-  auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size());
-  std::uninitialized_copy(nested.begin(), nested.end(), newChildren);
-  // Initialize with placement new.
-  new (storage) NestedPatternStorage(
-      k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter,
-      nullptr /* skip */);
-}
-
-Instruction::Kind NestedPattern::getKind() { return storage->kind; }
-
-ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() {
-  return storage->nestedPatterns;
-}
-
-FilterFunctionType NestedPattern::getFilterFunction() {
-  return storage->filter;
 }
 
 static bool isAffineIfOp(const Instruction &inst) {
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 2744b1d624c..432ad1f39b8 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -201,9 +201,6 @@ struct MaterializeVectorsPass : public FunctionPass {
 
   PassResult runOnFunction(Function *f) override;
 
-  // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
-  NestedPatternContext mlContext;
-
   static char passID;
 };
 
@@ -744,6 +741,9 @@ static bool materialize(Function *f,
 }
 
 PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
+  // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+  NestedPatternContext mlContext;
+
   // TODO(ntv): Check to see if this supports arbitrary top-level code.
   if (f->getBlocks().size() != 1)
     return success();
@@ -768,10 +768,11 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
     return matcher::operatesOnSuperVectors(opInst, subVectorType);
   };
   auto pat = Op(filter);
-  auto matches = pat.match(f);
+  SmallVector<NestedMatch, 8> matches;
+  pat.match(f, &matches);
   SetVector<OperationInst *> terminators;
   for (auto m : matches) {
-    terminators.insert(cast<OperationInst>(m.first));
+    terminators.insert(cast<OperationInst>(m.getMatchedInstruction()));
   }
 
   auto fail = materialize(f, terminators, &state);
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index a01b8fdf216..a9b9752ef51 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Support/Functional.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
 
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -94,9 +95,6 @@ struct VectorizerTestPass : public FunctionPass {
   void testComposeMaps(Function *f);
   void testNormalizeMaps(Function *f);
 
-  // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
-  NestedPatternContext MLContext;
-
   static char passID;
 };
 
@@ -128,9 +126,10 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
     return true;
   };
   auto pat = Op(filter);
-  auto matches = pat.match(f);
+  SmallVector<NestedMatch, 8> matches;
+  pat.match(f, &matches);
   for (auto m : matches) {
-    auto *opInst = cast<OperationInst>(m.first);
+    auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
     // This is a unit test that only checks and prints shape ratio.
     // As a consequence we write only Ops with a single return type for the
     // purpose of this test. If we need to test more intricate behavior in the
@@ -153,7 +152,7 @@ static std::string toString(Instruction *inst) {
   return res;
 }
 
-static NestedMatch matchTestSlicingOps(Function *f) {
+static NestedPattern patternTestSlicingOps() {
   // Just use a custom op name for this test, it makes life easier.
   constexpr auto kTestSlicingOpName = "slicing-test-op";
   using functional::map;
@@ -163,17 +162,18 @@
     const auto &opInst = cast<OperationInst>(inst);
     return opInst.getName().getStringRef() == kTestSlicingOpName;
   };
-  auto pat = Op(filter);
-  return pat.match(f);
+  return Op(filter);
 }
 
 void VectorizerTestPass::testBackwardSlicing(Function *f) {
-  auto matches = matchTestSlicingOps(f);
+  SmallVector<NestedMatch, 8> matches;
+  patternTestSlicingOps().match(f, &matches);
   for (auto m : matches) {
     SetVector<Instruction *> backwardSlice;
-    getBackwardSlice(m.first, &backwardSlice);
+    getBackwardSlice(m.getMatchedInstruction(), &backwardSlice);
     auto strs = map(toString, backwardSlice);
-    outs() << "\nmatched: " << *m.first << " backward static slice: ";
+    outs() << "\nmatched: " << *m.getMatchedInstruction()
+           << " backward static slice: ";
     for (const auto &s : strs) {
       outs() << "\n" << s;
     }
@@ -181,12 +181,14 @@
 }
 
 void VectorizerTestPass::testForwardSlicing(Function *f) {
-  auto matches = matchTestSlicingOps(f);
+  SmallVector<NestedMatch, 8> matches;
+  patternTestSlicingOps().match(f, &matches);
   for (auto m : matches) {
     SetVector<Instruction *> forwardSlice;
-    getForwardSlice(m.first, &forwardSlice);
+    getForwardSlice(m.getMatchedInstruction(), &forwardSlice);
     auto strs = map(toString, forwardSlice);
-    outs() << "\nmatched: " << *m.first << " forward static slice: ";
+    outs() << "\nmatched: " << *m.getMatchedInstruction()
+           << " forward static slice: ";
     for (const auto &s : strs) {
       outs() << "\n" << s;
     }
@@ -194,11 +196,12 @@
 }
 
 void VectorizerTestPass::testSlicing(Function *f) {
-  auto matches = matchTestSlicingOps(f);
+  SmallVector<NestedMatch, 8> matches;
+  patternTestSlicingOps().match(f, &matches);
   for (auto m : matches) {
-    SetVector<Instruction *> staticSlice = getSlice(m.first);
+    SetVector<Instruction *> staticSlice = getSlice(m.getMatchedInstruction());
     auto strs = map(toString, staticSlice);
-    outs() << "\nmatched: " << *m.first << " static slice: ";
+    outs() << "\nmatched: " << *m.getMatchedInstruction() << " static slice: ";
     for (const auto &s : strs) {
       outs() << "\n" << s;
     }
@@ -214,12 +217,12 @@ static bool customOpWithAffineMapAttribute(const Instruction &inst) {
 
 void VectorizerTestPass::testComposeMaps(Function *f) {
   using matcher::Op;
   auto pattern = Op(customOpWithAffineMapAttribute);
-  auto matches = pattern.match(f);
+  SmallVector<NestedMatch, 8> matches;
+  pattern.match(f, &matches);
   SmallVector<AffineMap, 4> maps;
   maps.reserve(matches.size());
-  std::reverse(matches.begin(), matches.end());
-  for (auto m : matches) {
-    auto *opInst = cast<OperationInst>(m.first);
+  for (auto m : llvm::reverse(matches)) {
+    auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
     auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
                    .cast<AffineMapAttr>()
                    .getValue();
@@ -248,29 +251,31 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) {
 
   // Save matched AffineApplyOp that all need to be erased in the end.
   auto pattern = Op(affineApplyOp);
-  auto toErase = pattern.match(f);
-  std::reverse(toErase.begin(), toErase.end());
+  SmallVector<NestedMatch, 8> toErase;
+  pattern.match(f, &toErase);
 
   {
     // Compose maps.
     auto pattern = Op(singleResultAffineApplyOpWithoutUses);
-    for (auto m : pattern.match(f)) {
-      auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>();
-      FuncBuilder b(m.first);
-
-      using ValueTy = decltype(*(app->getOperands().begin()));
-      SmallVector<Value *, 8> operands =
-          functional::map([](ValueTy v) { return static_cast<Value *>(v); },
-                          app->getOperands().begin(), app->getOperands().end());
+    SmallVector<NestedMatch, 8> matches;
+    pattern.match(f, &matches);
+    for (auto m : matches) {
+      auto app =
+          cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>();
+      FuncBuilder b(m.getMatchedInstruction());
+      SmallVector<Value *, 8> operands(app->getOperands());
       makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);
     }
   }
 
   // We should now be able to erase everything in reverse order in this test.
-  for (auto m : toErase) {
-    m.first->erase();
+  for (auto m : llvm::reverse(toErase)) {
+    m.getMatchedInstruction()->erase();
   }
 }
 
 PassResult VectorizerTestPass::runOnFunction(Function *f) {
+  // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+  NestedPatternContext mlContext;
+
   // Only support single block functions at this point.
   if (f->getBlocks().size() != 1)
     return success();
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index cfde1ecf0a8..73893599d17 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -655,9 +655,6 @@ struct Vectorize : public FunctionPass {
 
   PassResult runOnFunction(Function *f) override;
 
-  // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
-  NestedPatternContext MLContext;
-
   static char passID;
 };
 
@@ -703,13 +700,13 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
 ///   3. account for impact of vectorization on maximal loop fusion.
 /// Then we can quantify the above to build a cost model and search over
 /// strategies.
-static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern,
-                                 unsigned patternDepth,
+static bool analyzeProfitability(ArrayRef<NestedMatch> matches,
+                                 unsigned depthInPattern, unsigned patternDepth,
                                  VectorizationStrategy *strategy) {
   for (auto m : matches) {
-    auto *loop = cast<ForInst>(m.first);
-    bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
-                                     strategy);
+    auto *loop = cast<ForInst>(m.getMatchedInstruction());
+    bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
+                                     patternDepth, strategy);
     if (fail) {
       return fail;
     }
@@ -875,9 +872,10 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
         state->terminators.count(opInst) == 0;
   };
   auto loadAndStores = matcher::Op(notVectorizedThisPattern);
-  auto matches = loadAndStores.match(loop);
-  for (auto ls : matches) {
-    auto *opInst = cast<OperationInst>(ls.first);
+  SmallVector<NestedMatch, 8> loadAndStoresMatches;
+  loadAndStores.match(loop, &loadAndStoresMatches);
+  for (auto ls : loadAndStoresMatches) {
+    auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
     auto load = opInst->dyn_cast<LoadOp>();
     auto store = opInst->dyn_cast<StoreOp>();
     LLVM_DEBUG(opInst->print(dbgs()));
@@ -907,15 +905,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
 }
 
 /// Forward-declaration.
-static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state);
+static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
+                             VectorizationState *state);
 
 /// Apply vectorization of `loop` according to `state`. This is only triggered
 /// if all vectorizations in `childrenMatches` have already succeeded
 /// recursively in DFS post-order.
-static bool doVectorize(NestedMatch::EntryType oneMatch,
-                        VectorizationState *state) {
-  ForInst *loop = cast<ForInst>(oneMatch.first);
-  NestedMatch childrenMatches = oneMatch.second;
+static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
+  ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction());
+  auto childrenMatches = oneMatch.getMatchedChildren();
 
   // 1. DFS postorder recursion, if any of my children fails, I fail too.
   auto fail = vectorizeNonRoot(childrenMatches, state);
@@ -949,7 +947,8 @@ static bool doVectorize(NestedMatch::EntryType oneMatch,
 
 /// Non-root pattern iterates over the matches at this level, calls doVectorize
 /// and exits early if anything below fails.
-static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) {
+static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
+                             VectorizationState *state) {
   for (auto m : matches) {
     auto fail = doVectorize(m, state);
     if (fail) {
@@ -1185,99 +1184,100 @@ static bool vectorizeOperations(VectorizationState *state) {
 /// The root match thus needs to maintain a clone for handling failure.
 /// Each root may succeed independently but will otherwise clean after itself if
 /// anything below it fails.
-static bool vectorizeRootMatches(NestedMatch matches,
-                                 VectorizationStrategy *strategy) {
-  for (auto m : matches) {
-    auto *loop = cast<ForInst>(m.first);
-    VectorizationState state;
-    state.strategy = strategy;
-
-    // Since patterns are recursive, they can very well intersect.
-    // Since we do not want a fully greedy strategy in general, we decouple
-    // pattern matching, from profitability analysis, from application.
-    // As a consequence we must check that each root pattern is still
-    // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
-    // TODO(ntv): implement a non-greedy profitability analysis that keeps only
-    // non-intersecting patterns.
-    if (!isVectorizableLoop(*loop)) {
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
-      continue;
-    }
-    FuncBuilder builder(loop); // builder to insert in place of loop
-    ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
-    auto fail = doVectorize(m, &state);
-    /// Sets up error handling for this root loop. This is how the root match
-    /// maintains a clone for handling failure and restores the proper state via
-    /// RAII.
-    ScopeGuard sg2([&fail, loop, clonedLoop]() {
-      if (fail) {
-        loop->getInductionVar()->replaceAllUsesWith(
-            clonedLoop->getInductionVar());
-        loop->erase();
-      } else {
-        clonedLoop->erase();
-      }
-    });
+static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
+  auto *loop = cast<ForInst>(m.getMatchedInstruction());
+  VectorizationState state;
+  state.strategy = strategy;
+
+  // Since patterns are recursive, they can very well intersect.
+  // Since we do not want a fully greedy strategy in general, we decouple
+  // pattern matching, from profitability analysis, from application.
+  // As a consequence we must check that each root pattern is still
+  // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
+  // TODO(ntv): implement a non-greedy profitability analysis that keeps only
+  // non-intersecting patterns.
+  if (!isVectorizableLoop(*loop)) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
+    return true;
+  }
+  FuncBuilder builder(loop); // builder to insert in place of loop
+  ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
+  auto fail = doVectorize(m, &state);
+  /// Sets up error handling for this root loop. This is how the root match
+  /// maintains a clone for handling failure and restores the proper state via
+  /// RAII.
+  ScopeGuard sg2([&fail, loop, clonedLoop]() {
    if (fail) {
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize");
-      continue;
+      loop->getInductionVar()->replaceAllUsesWith(
+          clonedLoop->getInductionVar());
+      loop->erase();
+    } else {
+      clonedLoop->erase();
    }
+  });
+  if (fail) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize");
+    return true;
+  }
 
-    // Form the root operationsthat have been set in the replacementMap.
-    // For now, these roots are the loads for which vector_transfer_read
-    // operations have been inserted.
-    auto getDefiningInst = [](const Value *val) {
-      return const_cast<Value *>(val)->getDefiningInst();
-    };
-    using ReferenceTy = decltype(*(state.replacementMap.begin()));
-    auto getKey = [](ReferenceTy it) { return it.first; };
-    auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
-
-    // Vectorize the root operations and everything reached by use-def chains
-    // except the terminators (store instructions) that need to be
-    // post-processed separately.
-    fail = vectorizeOperations(&state);
-    if (fail) {
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
-      continue;
-    }
+  // Form the root operations that have been set in the replacementMap.
+  // For now, these roots are the loads for which vector_transfer_read
+  // operations have been inserted.
+  auto getDefiningInst = [](const Value *val) {
+    return const_cast<Value *>(val)->getDefiningInst();
+  };
+  using ReferenceTy = decltype(*(state.replacementMap.begin()));
+  auto getKey = [](ReferenceTy it) { return it.first; };
+  auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
+
+  // Vectorize the root operations and everything reached by use-def chains
+  // except the terminators (store instructions) that need to be
+  // post-processed separately.
+  fail = vectorizeOperations(&state);
+  if (fail) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
+    return true;
+  }
 
-    // Finally, vectorize the terminators. If anything fails to vectorize, skip.
-    auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
-      if (fail) {
-        return;
-      }
-      FuncBuilder b(inst);
-      auto *res = vectorizeOneOperationInst(&b, inst, &state);
-      if (res == nullptr) {
-        fail = true;
-      }
-    };
-    apply(vectorizeOrFail, state.terminators);
+  // Finally, vectorize the terminators. If anything fails to vectorize, skip.
+  auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
     if (fail) {
-      LLVM_DEBUG(
-          dbgs() << "\n[early-vect]+++++ failed to vectorize terminators");
-      continue;
-    } else {
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
+      return;
    }
-
-    // Finish this vectorization pattern.
-    state.finishVectorizationPattern();
+    FuncBuilder b(inst);
+    auto *res = vectorizeOneOperationInst(&b, inst, &state);
+    if (res == nullptr) {
+      fail = true;
+    }
+  };
+  apply(vectorizeOrFail, state.terminators);
+  if (fail) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminators");
+    return true;
+  } else {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
   }
+
+  // Finish this vectorization pattern.
+  state.finishVectorizationPattern();
   return false;
 }
 
 /// Applies vectorization to the current Function by searching over a bunch of
 /// predetermined patterns.
 PassResult Vectorize::runOnFunction(Function *f) {
-  for (auto pat : makePatterns()) {
+  // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+  NestedPatternContext mlContext;
+
+  for (auto &pat : makePatterns()) {
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
     LLVM_DEBUG(f->print(dbgs()));
     unsigned patternDepth = pat.getDepth();
-    auto matches = pat.match(f);
+
+    SmallVector<NestedMatch, 8> matches;
+    pat.match(f, &matches);
     // Iterate over all the top-level matches and vectorize eagerly.
     // This automatically prunes intersecting matches.
     for (auto m : matches) {
@@ -1285,16 +1285,16 @@ PassResult Vectorize::runOnFunction(Function *f) {
       // TODO(ntv): depending on profitability, elect to reduce the vector size.
       strategy.vectorSizes.assign(clVirtualVectorSize.begin(),
                                   clVirtualVectorSize.end());
-      auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy);
+      auto fail = analyzeProfitability(m.getMatchedChildren(), 1, patternDepth,
+                                       &strategy);
       if (fail) {
         continue;
       }
-      auto *loop = cast<ForInst>(m.first);
+      auto *loop = cast<ForInst>(m.getMatchedInstruction());
       vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
       // TODO(ntv): if pattern does not apply, report it; alter the
       // cost/benefit.
-      fail = vectorizeRootMatches(matches, &strategy);
-      assert(!fail && "top-level failure should not happen");
+      fail = vectorizeRootMatch(m, &strategy);
       // TODO(ntv): some diagnostics.
     }
   }
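
A note on consuming nested results: each top-level NestedMatch now carries its
children directly, so clients recurse over getMatchedChildren() the way the
refactored analyzeProfitability and doVectorize do above. A sketch under the
same assumptions (countMatchedInsts is a hypothetical helper, not part of the
patch):

    #include "mlir/Analysis/NestedMatcher.h"
    #include "llvm/ADT/ArrayRef.h"

    using namespace mlir;

    // Count every instruction captured along a branch of matches by a DFS
    // over getMatchedChildren(), mirroring the traversals in this patch.
    static unsigned countMatchedInsts(ArrayRef<NestedMatch> matches) {
      unsigned count = 0;
      for (auto m : matches) {
        ++count; // the instruction captured at this nesting level
        count += countMatchedInsts(m.getMatchedChildren());
      }
      return count;
    }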

