| author | Nicolas Vasilache <ntv@google.com> | 2019-01-31 07:16:29 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:04:07 -0700 |
| commit | d4921f4a96a46e24c2fb76ca7feb66e6894395d2 (patch) | |
| tree | f9d08a3000a6db20ffe76a4c4d0fb7448d3a7b07 | |
| parent | 35200435e74533f8c703a1db5ebdbf6460c728c5 (diff) | |
| download | bcm5719-llvm-d4921f4a96a46e24c2fb76ca7feb66e6894395d2.tar.gz bcm5719-llvm-d4921f4a96a46e24c2fb76ca7feb66e6894395d2.zip | |
Address performance issue in NestedMatcher
A performance issue was reported due to the usage of NestedMatcher in
ComposeAffineMaps. The main culprit was the pervasive copying that occurred in
`matchOne` whenever even a single element was appended.
This CL simplifies the implementation and removes one level of indirection by
getting rid of the auxiliary storage classes and streamlining the API (a usage
sketch of the updated API follows below).
Users of the API are updated accordingly.
The implementation was tested on a heavily unrolled example with
ComposeAffineMaps and is now close in performance to an implementation based on
a stateless InstWalker.
As a reminder, the whole ComposeAffineMaps pass is slated to disappear, but the bug report was very useful as a stress test for NestedMatchers.
Lastly, the following cleanups reported by @aminim were addressed:
1. make NestedPatternContext scoped within runFunction rather than at the Pass level; the previous placement stemmed from a misunderstanding of Pass lifetime;
2. use defensive assertions in the constructor of NestedPatternContext to make it clear that only a single such locally scoped context may exist at a time (an illustrative sketch follows after the diff).
PiperOrigin-RevId: 231781279
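
To make the API change concrete before the diff, here is a minimal usage sketch of the post-patch API. `MyPass` is a hypothetical pass name used only for illustration; `NestedPatternContext`, `matcher::Op`, `matcher::isLoadOrStore`, the out-parameter `match` overloads, and `getMatchedInstruction`/`getMatchedChildren` are the entities introduced or reworked by this patch.

```cpp
// Sketch only: MyPass is illustrative; the calls mirror the updated API below,
// where match() fills a caller-owned vector instead of returning a NestedMatch.
PassResult MyPass::runOnFunction(Function *f) {
  // RAII context scoped to this run; the BumpPtrAllocator backing
  // NestedPattern and NestedMatch is freed when it goes out of scope.
  NestedPatternContext mlContext;

  auto pattern = matcher::Op(matcher::isLoadOrStore);
  SmallVector<NestedMatch, 8> matches;
  pattern.match(f, &matches);

  for (auto m : matches) {
    auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
    (void)opInst; // m.getMatchedChildren() exposes the nested matches, if any.
  }
  return success();
}
```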
| -rw-r--r-- | mlir/include/mlir/Analysis/NestedMatcher.h | 165 |
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 14 |
| -rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 158 |
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 11 |
| -rw-r--r-- | mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 71 |
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 192 |
6 files changed, 291 insertions, 320 deletions
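
Before wading into the diff itself, the core of the header change can be summarized as follows. This is a simplified excerpt of the new `NestedMatch` from `mlir/include/mlir/Analysis/NestedMatcher.h` in this patch, with the friend declarations and the allocator accessor omitted.

```cpp
// NestedMatch after this patch: a cheap value type holding the matched
// instruction and its child matches, both allocated in the BumpPtrAllocator
// owned by the enclosing NestedPatternContext.
struct NestedMatch {
  static NestedMatch build(Instruction *instruction,
                           ArrayRef<NestedMatch> nestedMatches);

  explicit operator bool() { return matchedInstruction != nullptr; }

  Instruction *getMatchedInstruction() { return matchedInstruction; }
  ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; }

private:
  NestedMatch() = default;

  Instruction *matchedInstruction;       // root instruction of this match
  ArrayRef<NestedMatch> matchedChildren; // matches nested under it
};
```

Copying a `NestedMatch` now copies a pointer and an `ArrayRef`, which is what removes the per-append copies that motivated the bug report.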
diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 161bb217a10..2a1c469348d 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -20,20 +20,19 @@ #include "mlir/IR/InstVisitor.h" #include "llvm/Support/Allocator.h" -#include <utility> namespace mlir { -struct NestedPatternStorage; -struct NestedMatchStorage; +struct NestedPattern; class Instruction; -/// An NestedPattern captures nested patterns. It is used in conjunction with -/// a scoped NestedPatternContext which is an llvm::BumPtrAllocator that -/// handles memory allocations efficiently and avoids ownership issues. +/// An NestedPattern captures nested patterns in the IR. +/// It is used in conjunction with a scoped NestedPatternContext which is an +/// llvm::BumpPtrAllocator that handles memory allocations efficiently and +/// avoids ownership issues. /// -/// In order to use NestedPatterns, first create a scoped context. When the -/// context goes out of scope, everything is freed. +/// In order to use NestedPatterns, first create a scoped context. +/// When the context goes out of scope, everything is freed. /// This design simplifies the API by avoiding references to the context and /// makes it clear that references to matchers must not escape. /// @@ -45,109 +44,145 @@ class Instruction; /// // do work on matches /// } // everything is freed /// - -/// Recursive abstraction for matching results. -/// Provides iteration over the Instruction* captured by a Matcher. /// -/// Implemented as a POD value-type with underlying storage pointer. -/// The underlying storage lives in a scoped bumper allocator whose lifetime -/// is managed by an RAII NestedPatternContext. -/// This is used by value everywhere. +/// Nested abstraction for matching results. +/// Provides access to the nested Instruction* captured by a Matcher. +/// +/// A NestedMatch contains an Instruction* and the children NestedMatch and is +/// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose +/// lifetime is managed by an RAII NestedPatternContext. struct NestedMatch { - using EntryType = std::pair<Instruction *, NestedMatch>; - using iterator = EntryType *; - - static NestedMatch build(ArrayRef<NestedMatch::EntryType> elements = {}); + static NestedMatch build(Instruction *instruction, + ArrayRef<NestedMatch> nestedMatches); NestedMatch(const NestedMatch &) = default; NestedMatch &operator=(const NestedMatch &) = default; - explicit operator bool() { return !empty(); } + explicit operator bool() { return matchedInstruction != nullptr; } - iterator begin(); - iterator end(); - EntryType &front(); - EntryType &back(); - unsigned size() { return end() - begin(); } - unsigned empty() { return size() == 0; } + Instruction *getMatchedInstruction() { return matchedInstruction; } + ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; } private: friend class NestedPattern; friend class NestedPatternContext; - friend class NestedMatchStorage; /// Underlying global bump allocator managed by a NestedPatternContext. static llvm::BumpPtrAllocator *&allocator(); - NestedMatch(NestedMatchStorage *storage) : storage(storage){}; - - /// Copy the specified array of elements into memory managed by our bump - /// pointer allocator. The elements are all PODs by constructions. - static NestedMatch copyInto(ArrayRef<NestedMatch::EntryType> elements); + NestedMatch() = default; - /// POD payload. 
- NestedMatchStorage *storage; + /// Payload, holds a NestedMatch and all its children along this branch. + Instruction *matchedInstruction; + ArrayRef<NestedMatch> matchedChildren; }; -/// A NestedPattern is a special type of InstWalker that: +/// A NestedPattern is a nested InstWalker that: /// 1. recursively matches a substructure in the tree; /// 2. uses a filter function to refine matches with extra semantic /// constraints (passed via a lambda of type FilterFunctionType); -/// 3. TODO(ntv) Optionally applies actions (lambda). +/// 3. TODO(ntv) optionally applies actions (lambda). /// -/// Implemented as a POD value-type with underlying storage pointer. -/// The underlying storage lives in a scoped bumper allocator whose lifetime -/// is managed by an RAII NestedPatternContext. -/// This should be used by value everywhere. +/// Nested patterns are meant to capture imperfectly nested loops while matching +/// properties over the whole loop nest. For instance, in vectorization we are +/// interested in capturing all the imperfectly nested loops of a certain type +/// and such that all the load and stores have certain access patterns along the +/// loops' induction variables). Such NestedMatches are first captured using the +/// `match` function and are later processed to analyze properties and apply +/// transformations in a non-greedy way. +/// +/// The NestedMatches captured in the IR can grow large, especially after +/// aggressive unrolling. As experience has shown, it is generally better to use +/// a plain InstWalker to match flat patterns but the current implementation is +/// competitive nonetheless. using FilterFunctionType = std::function<bool(const Instruction &)>; static bool defaultFilterFunction(const Instruction &) { return true; }; -struct NestedPattern : public InstWalker<NestedPattern> { +struct NestedPattern { NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested, FilterFunctionType filter = defaultFilterFunction); NestedPattern(const NestedPattern &) = default; NestedPattern &operator=(const NestedPattern &) = default; - /// Returns all the matches in `function`. - NestedMatch match(Function *function); + /// Returns all the top-level matches in `function`. + void match(Function *function, SmallVectorImpl<NestedMatch> *matches) { + State state(*this, matches); + state.walkPostOrder(function); + } - /// Returns all the matches nested under `instruction`. - NestedMatch match(Instruction *instruction); + /// Returns all the top-level matches in `inst`. + void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) { + State state(*this, matches); + state.walkPostOrder(inst); + } - unsigned getDepth(); + /// Returns the depth of the pattern. + unsigned getDepth() const; private: friend class NestedPatternContext; - friend InstWalker<NestedPattern>; + friend class NestedMatch; + friend class InstWalker<NestedPattern>; + friend struct State; + + /// Helper state that temporarily holds matches for the next level of nesting. + struct State : public InstWalker<State> { + State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches) + : pattern(pattern), matches(matches) {} + void visitForInst(ForInst *forInst) { pattern.matchOne(forInst, matches); } + void visitOperationInst(OperationInst *opInst) { + pattern.matchOne(opInst, matches); + } + + private: + NestedPattern &pattern; + SmallVectorImpl<NestedMatch> *matches; + }; /// Underlying global bump allocator managed by a NestedPatternContext. 
static llvm::BumpPtrAllocator *&allocator(); - Instruction::Kind getKind(); - ArrayRef<NestedPattern> getNestedPatterns(); - FilterFunctionType getFilterFunction(); - - void matchOne(Instruction *elem); - - void visitForInst(ForInst *forInst) { matchOne(forInst); } - void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } - - /// POD paylod. - /// Storage for the PatternMatcher. - NestedPatternStorage *storage; - - // By-value POD wrapper to underlying storage pointer. - NestedMatch matches; + /// Matches this pattern against a single `inst` and fills matches with the + /// result. + void matchOne(Instruction *inst, SmallVectorImpl<NestedMatch> *matches); + + /// Instruction kind matched by this pattern. + Instruction::Kind kind; + + /// Nested patterns to be matched. + ArrayRef<NestedPattern> nestedPatterns; + + /// Extra filter function to apply to prune patterns as the IR is walked. + FilterFunctionType filter; + + /// skip is an implementation detail needed so that we can implement match + /// without switching on the type of the Instruction. The idea is that a + /// NestedPattern first checks if it matches locally and then recursively + /// applies its nested matchers to its elem->nested. Since we want to rely on + /// the InstWalker impl rather than duplicate its the logic, we allow an + /// off-by-one traversal to account for the fact that we write: + /// + /// void match(Instruction *elem) { + /// for (auto &c : getNestedPatterns()) { + /// NestedPattern childPattern(...); + /// ^~~~ Needs off-by-one skip. + /// + Instruction *skip; }; /// RAII structure to transparently manage the bump allocator for -/// NestedPattern and NestedMatch classes. +/// NestedPattern and NestedMatch classes. This avoids passing a context to +/// all the API functions. struct NestedPatternContext { NestedPatternContext() { - NestedPattern::allocator() = &allocator; + assert(NestedMatch::allocator() == nullptr && + "Only a single NestedPatternContext is supported"); + assert(NestedPattern::allocator() == nullptr && + "Only a single NestedPatternContext is supported"); NestedMatch::allocator() = &allocator; + NestedPattern::allocator() = &allocator; } ~NestedPatternContext() { - NestedPattern::allocator() = nullptr; NestedMatch::allocator() = nullptr; + NestedPattern::allocator() = nullptr; } llvm::BumpPtrAllocator allocator; }; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 07c903a6613..7d88a3d9b9f 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -242,7 +242,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, // No vectorization across conditionals for now. 
auto conditionals = matcher::If(); auto *forInst = const_cast<ForInst *>(&loop); - auto conditionalsMatched = conditionals.match(forInst); + SmallVector<NestedMatch, 8> conditionalsMatched; + conditionals.match(forInst, &conditionalsMatched); if (!conditionalsMatched.empty()) { return false; } @@ -252,21 +253,24 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, auto &opInst = cast<OperationInst>(inst); return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>(); }); - auto regionsMatched = regions.match(forInst); + SmallVector<NestedMatch, 8> regionsMatched; + regions.match(forInst, ®ionsMatched); if (!regionsMatched.empty()) { return false; } auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); - auto vectorTransfersMatched = vectorTransfers.match(forInst); + SmallVector<NestedMatch, 8> vectorTransfersMatched; + vectorTransfers.match(forInst, &vectorTransfersMatched); if (!vectorTransfersMatched.empty()) { return false; } auto loadAndStores = matcher::Op(matcher::isLoadOrStore); - auto loadAndStoresMatched = loadAndStores.match(forInst); + SmallVector<NestedMatch, 8> loadAndStoresMatched; + loadAndStores.match(forInst, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { - auto *op = cast<OperationInst>(ls.first); + auto *op = cast<OperationInst>(ls.getMatchedInstruction()); auto load = op->dyn_cast<LoadOp>(); auto store = op->dyn_cast<StoreOp>(); // Only scalar types are considered vectorizable, all load/store must be diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 491a9bef1b9..43e3b332a58 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -20,93 +20,49 @@ #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/raw_ostream.h" -namespace mlir { - -/// Underlying storage for NestedMatch. -struct NestedMatchStorage { - MutableArrayRef<NestedMatch::EntryType> matches; -}; - -/// Underlying storage for NestedPattern. -struct NestedPatternStorage { - NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c, - FilterFunctionType filter, Instruction *skip) - : kind(k), nestedPatterns(c), filter(filter), skip(skip) {} - - Instruction::Kind kind; - ArrayRef<NestedPattern> nestedPatterns; - FilterFunctionType filter; - /// skip is needed so that we can implement match without switching on the - /// type of the Instruction. - /// The idea is that a NestedPattern first checks if it matches locally - /// and then recursively applies its nested matchers to its elem->nested. - /// Since we want to rely on the InstWalker impl rather than duplicate its - /// the logic, we allow an off-by-one traversal to account for the fact that - /// we write: - /// - /// void match(Instruction *elem) { - /// for (auto &c : getNestedPatterns()) { - /// NestedPattern childPattern(...); - /// ^~~~ Needs off-by-one skip. 
- /// - Instruction *skip; -}; - -} // end namespace mlir - using namespace mlir; llvm::BumpPtrAllocator *&NestedMatch::allocator() { - static thread_local llvm::BumpPtrAllocator *allocator = nullptr; + thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } -NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) { - auto *matches = - allocator()->Allocate<NestedMatch::EntryType>(elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), matches); - auto *storage = allocator()->Allocate<NestedMatchStorage>(); - new (storage) NestedMatchStorage(); - storage->matches = - MutableArrayRef<NestedMatch::EntryType>(matches, elements.size()); +NestedMatch NestedMatch::build(Instruction *instruction, + ArrayRef<NestedMatch> nestedMatches) { auto *result = allocator()->Allocate<NestedMatch>(); - new (result) NestedMatch(storage); + auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size()); + std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); + new (result) NestedMatch(); + result->matchedInstruction = instruction; + result->matchedChildren = + ArrayRef<NestedMatch>(children, nestedMatches.size()); return *result; } -NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); } -NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); } -NestedMatch::EntryType &NestedMatch::front() { - return *storage->matches.begin(); -} -NestedMatch::EntryType &NestedMatch::back() { - return *(storage->matches.begin() + size() - 1); -} - -/// Calls walk on `function`. -NestedMatch NestedPattern::match(Function *function) { - assert(!matches && "NestedPattern already matched!"); - this->walkPostOrder(function); - return matches; +llvm::BumpPtrAllocator *&NestedPattern::allocator() { + thread_local llvm::BumpPtrAllocator *allocator = nullptr; + return allocator; } -/// Calls walk on `instruction`. -NestedMatch NestedPattern::match(Instruction *instruction) { - assert(!matches && "NestedPattern already matched!"); - this->walkPostOrder(instruction); - return matches; +NestedPattern::NestedPattern(Instruction::Kind k, + ArrayRef<NestedPattern> nested, + FilterFunctionType filter) + : kind(k), nestedPatterns(ArrayRef<NestedPattern>(nested)), filter(filter) { + auto *newNested = allocator()->Allocate<NestedPattern>(nested.size()); + std::uninitialized_copy(nested.begin(), nested.end(), newNested); + nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size()); } -unsigned NestedPattern::getDepth() { - auto nested = getNestedPatterns(); - if (nested.empty()) { +unsigned NestedPattern::getDepth() const { + if (nestedPatterns.empty()) { return 1; } unsigned depth = 0; - for (auto c : nested) { + for (auto &c : nestedPatterns) { depth = std::max(depth, c.getDepth()); } return depth + 1; @@ -122,69 +78,39 @@ unsigned NestedPattern::getDepth() { /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. 
-void NestedPattern::matchOne(Instruction *elem) { - if (storage->skip == elem) { +void NestedPattern::matchOne(Instruction *inst, + SmallVectorImpl<NestedMatch> *matches) { + if (skip == inst) { return; } // Structural filter - if (elem->getKind() != getKind()) { + if (inst->getKind() != kind) { return; } // Local custom filter function - if (!getFilterFunction()(*elem)) { + if (!filter(*inst)) { return; } - SmallVector<NestedMatch::EntryType, 8> nestedEntries; - for (auto c : getNestedPatterns()) { - /// We create a new nestedPattern here because a matcher holds its - /// results. So we concretely need multiple copies of a given matcher, one - /// for each matching result. - NestedPattern nestedPattern(c); + if (nestedPatterns.empty()) { + SmallVector<NestedMatch, 8> nestedMatches; + matches->push_back(NestedMatch::build(inst, nestedMatches)); + return; + } + // Take a copy of each nested pattern so we can match it. + for (auto nestedPattern : nestedPatterns) { + SmallVector<NestedMatch, 8> nestedMatches; // Skip elem in the walk immediately following. Without this we would // essentially need to reimplement walkPostOrder here. - nestedPattern.storage->skip = elem; - nestedPattern.walkPostOrder(elem); - if (!nestedPattern.matches) { + nestedPattern.skip = inst; + nestedPattern.match(inst, &nestedMatches); + // If we could not match even one of the specified nestedPattern, early exit + // as this whole branch is not a match. + if (nestedMatches.empty()) { return; } - for (auto m : nestedPattern.matches) { - nestedEntries.push_back(m); - } + matches->push_back(NestedMatch::build(inst, nestedMatches)); } - - SmallVector<NestedMatch::EntryType, 8> newEntries( - matches.storage->matches.begin(), matches.storage->matches.end()); - newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries))); - matches = NestedMatch::build(newEntries); -} - -llvm::BumpPtrAllocator *&NestedPattern::allocator() { - static thread_local llvm::BumpPtrAllocator *allocator = nullptr; - return allocator; -} - -NestedPattern::NestedPattern(Instruction::Kind k, - ArrayRef<NestedPattern> nested, - FilterFunctionType filter) - : storage(allocator()->Allocate<NestedPatternStorage>()), - matches(NestedMatch::build({})) { - auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size()); - std::uninitialized_copy(nested.begin(), nested.end(), newChildren); - // Initialize with placement new. - new (storage) NestedPatternStorage( - k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter, - nullptr /* skip */); -} - -Instruction::Kind NestedPattern::getKind() { return storage->kind; } - -ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() { - return storage->nestedPatterns; -} - -FilterFunctionType NestedPattern::getFilterFunction() { - return storage->filter; } static bool isAffineIfOp(const Instruction &inst) { diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2744b1d624c..432ad1f39b8 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -201,9 +201,6 @@ struct MaterializeVectorsPass : public FunctionPass { PassResult runOnFunction(Function *f) override; - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - NestedPatternContext mlContext; - static char passID; }; @@ -744,6 +741,9 @@ static bool materialize(Function *f, } PassResult MaterializeVectorsPass::runOnFunction(Function *f) { + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. 
+ NestedPatternContext mlContext; + // TODO(ntv): Check to see if this supports arbitrary top-level code. if (f->getBlocks().size() != 1) return success(); @@ -768,10 +768,11 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { return matcher::operatesOnSuperVectors(opInst, subVectorType); }; auto pat = Op(filter); - auto matches = pat.match(f); + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); SetVector<OperationInst *> terminators; for (auto m : matches) { - terminators.insert(cast<OperationInst>(m.first)); + terminators.insert(cast<OperationInst>(m.getMatchedInstruction())); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index a01b8fdf216..a9b9752ef51 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -30,6 +30,7 @@ #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/Passes.h" +#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -94,9 +95,6 @@ struct VectorizerTestPass : public FunctionPass { void testComposeMaps(Function *f); void testNormalizeMaps(Function *f); - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - NestedPatternContext MLContext; - static char passID; }; @@ -128,9 +126,10 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { return true; }; auto pat = Op(filter); - auto matches = pat.match(f); + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); for (auto m : matches) { - auto *opInst = cast<OperationInst>(m.first); + auto *opInst = cast<OperationInst>(m.getMatchedInstruction()); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -153,7 +152,7 @@ static std::string toString(Instruction *inst) { return res; } -static NestedMatch matchTestSlicingOps(Function *f) { +static NestedPattern patternTestSlicingOps() { // Just use a custom op name for this test, it makes life easier. 
constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; @@ -163,17 +162,18 @@ static NestedMatch matchTestSlicingOps(Function *f) { const auto &opInst = cast<OperationInst>(inst); return opInst.getName().getStringRef() == kTestSlicingOpName; }; - auto pat = Op(filter); - return pat.match(f); + return Op(filter); } void VectorizerTestPass::testBackwardSlicing(Function *f) { - auto matches = matchTestSlicingOps(f); + SmallVector<NestedMatch, 8> matches; + patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector<Instruction *> backwardSlice; - getBackwardSlice(m.first, &backwardSlice); + getBackwardSlice(m.getMatchedInstruction(), &backwardSlice); auto strs = map(toString, backwardSlice); - outs() << "\nmatched: " << *m.first << " backward static slice: "; + outs() << "\nmatched: " << *m.getMatchedInstruction() + << " backward static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -181,12 +181,14 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) { } void VectorizerTestPass::testForwardSlicing(Function *f) { - auto matches = matchTestSlicingOps(f); + SmallVector<NestedMatch, 8> matches; + patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector<Instruction *> forwardSlice; - getForwardSlice(m.first, &forwardSlice); + getForwardSlice(m.getMatchedInstruction(), &forwardSlice); auto strs = map(toString, forwardSlice); - outs() << "\nmatched: " << *m.first << " forward static slice: "; + outs() << "\nmatched: " << *m.getMatchedInstruction() + << " forward static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -194,11 +196,12 @@ void VectorizerTestPass::testForwardSlicing(Function *f) { } void VectorizerTestPass::testSlicing(Function *f) { - auto matches = matchTestSlicingOps(f); + SmallVector<NestedMatch, 8> matches; + patternTestSlicingOps().match(f, &matches); for (auto m : matches) { - SetVector<Instruction *> staticSlice = getSlice(m.first); + SetVector<Instruction *> staticSlice = getSlice(m.getMatchedInstruction()); auto strs = map(toString, staticSlice); - outs() << "\nmatched: " << *m.first << " static slice: "; + outs() << "\nmatched: " << *m.getMatchedInstruction() << " static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -214,12 +217,12 @@ static bool customOpWithAffineMapAttribute(const Instruction &inst) { void VectorizerTestPass::testComposeMaps(Function *f) { using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); - auto matches = pattern.match(f); + SmallVector<NestedMatch, 8> matches; + pattern.match(f, &matches); SmallVector<AffineMap, 4> maps; maps.reserve(matches.size()); - std::reverse(matches.begin(), matches.end()); - for (auto m : matches) { - auto *opInst = cast<OperationInst>(m.first); + for (auto m : llvm::reverse(matches)) { + auto *opInst = cast<OperationInst>(m.getMatchedInstruction()); auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast<AffineMapAttr>() .getValue(); @@ -248,29 +251,31 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { // Save matched AffineApplyOp that all need to be erased in the end. auto pattern = Op(affineApplyOp); - auto toErase = pattern.match(f); - std::reverse(toErase.begin(), toErase.end()); + SmallVector<NestedMatch, 8> toErase; + pattern.match(f, &toErase); { // Compose maps. 
auto pattern = Op(singleResultAffineApplyOpWithoutUses); - for (auto m : pattern.match(f)) { - auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>(); - FuncBuilder b(m.first); - - using ValueTy = decltype(*(app->getOperands().begin())); - SmallVector<Value *, 8> operands = - functional::map([](ValueTy v) { return static_cast<Value *>(v); }, - app->getOperands().begin(), app->getOperands().end()); + SmallVector<NestedMatch, 8> matches; + pattern.match(f, &matches); + for (auto m : matches) { + auto app = + cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>(); + FuncBuilder b(m.getMatchedInstruction()); + SmallVector<Value *, 8> operands(app->getOperands()); makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); } } // We should now be able to erase everything in reverse order in this test. - for (auto m : toErase) { - m.first->erase(); + for (auto m : llvm::reverse(toErase)) { + m.getMatchedInstruction()->erase(); } } PassResult VectorizerTestPass::runOnFunction(Function *f) { + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. + NestedPatternContext mlContext; + // Only support single block functions at this point. if (f->getBlocks().size() != 1) return success(); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index cfde1ecf0a8..73893599d17 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -655,9 +655,6 @@ struct Vectorize : public FunctionPass { PassResult runOnFunction(Function *f) override; - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - NestedPatternContext MLContext; - static char passID; }; @@ -703,13 +700,13 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, /// 3. account for impact of vectorization on maximal loop fusion. /// Then we can quantify the above to build a cost model and search over /// strategies. -static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern, - unsigned patternDepth, +static bool analyzeProfitability(ArrayRef<NestedMatch> matches, + unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast<ForInst>(m.first); - bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth, - strategy); + auto *loop = cast<ForInst>(m.getMatchedInstruction()); + bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, + patternDepth, strategy); if (fail) { return fail; } @@ -875,9 +872,10 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, state->terminators.count(opInst) == 0; }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); - auto matches = loadAndStores.match(loop); - for (auto ls : matches) { - auto *opInst = cast<OperationInst>(ls.first); + SmallVector<NestedMatch, 8> loadAndStoresMatches; + loadAndStores.match(loop, &loadAndStoresMatches); + for (auto ls : loadAndStoresMatches) { + auto *opInst = cast<OperationInst>(ls.getMatchedInstruction()); auto load = opInst->dyn_cast<LoadOp>(); auto store = opInst->dyn_cast<StoreOp>(); LLVM_DEBUG(opInst->print(dbgs())); @@ -907,15 +905,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { } /// Forward-declaration. -static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state); +static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches, + VectorizationState *state); /// Apply vectorization of `loop` according to `state`. 
This is only triggered /// if all vectorizations in `childrenMatches` have already succeeded /// recursively in DFS post-order. -static bool doVectorize(NestedMatch::EntryType oneMatch, - VectorizationState *state) { - ForInst *loop = cast<ForInst>(oneMatch.first); - NestedMatch childrenMatches = oneMatch.second; +static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { + ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction()); + auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. auto fail = vectorizeNonRoot(childrenMatches, state); @@ -949,7 +947,8 @@ static bool doVectorize(NestedMatch::EntryType oneMatch, /// Non-root pattern iterates over the matches at this level, calls doVectorize /// and exits early if anything below fails. -static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) { +static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches, + VectorizationState *state) { for (auto m : matches) { auto fail = doVectorize(m, state); if (fail) { @@ -1185,99 +1184,100 @@ static bool vectorizeOperations(VectorizationState *state) { /// The root match thus needs to maintain a clone for handling failure. /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. -static bool vectorizeRootMatches(NestedMatch matches, - VectorizationStrategy *strategy) { - for (auto m : matches) { - auto *loop = cast<ForInst>(m.first); - VectorizationState state; - state.strategy = strategy; - - // Since patterns are recursive, they can very well intersect. - // Since we do not want a fully greedy strategy in general, we decouple - // pattern matching, from profitability analysis, from application. - // As a consequence we must check that each root pattern is still - // vectorizable. If a pattern is not vectorizable anymore, we just skip it. - // TODO(ntv): implement a non-greedy profitability analysis that keeps only - // non-intersecting patterns. - if (!isVectorizableLoop(*loop)) { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); - continue; - } - FuncBuilder builder(loop); // builder to insert in place of loop - ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop)); - auto fail = doVectorize(m, &state); - /// Sets up error handling for this root loop. This is how the root match - /// maintains a clone for handling failure and restores the proper state via - /// RAII. - ScopeGuard sg2([&fail, loop, clonedLoop]() { - if (fail) { - loop->getInductionVar()->replaceAllUsesWith( - clonedLoop->getInductionVar()); - loop->erase(); - } else { - clonedLoop->erase(); - } - }); +static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { + auto *loop = cast<ForInst>(m.getMatchedInstruction()); + VectorizationState state; + state.strategy = strategy; + + // Since patterns are recursive, they can very well intersect. + // Since we do not want a fully greedy strategy in general, we decouple + // pattern matching, from profitability analysis, from application. + // As a consequence we must check that each root pattern is still + // vectorizable. If a pattern is not vectorizable anymore, we just skip it. + // TODO(ntv): implement a non-greedy profitability analysis that keeps only + // non-intersecting patterns. 
+ if (!isVectorizableLoop(*loop)) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); + return true; + } + FuncBuilder builder(loop); // builder to insert in place of loop + ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop)); + auto fail = doVectorize(m, &state); + /// Sets up error handling for this root loop. This is how the root match + /// maintains a clone for handling failure and restores the proper state via + /// RAII. + ScopeGuard sg2([&fail, loop, clonedLoop]() { if (fail) { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize"); - continue; + loop->getInductionVar()->replaceAllUsesWith( + clonedLoop->getInductionVar()); + loop->erase(); + } else { + clonedLoop->erase(); } + }); + if (fail) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize"); + return true; + } - // Form the root operationsthat have been set in the replacementMap. - // For now, these roots are the loads for which vector_transfer_read - // operations have been inserted. - auto getDefiningInst = [](const Value *val) { - return const_cast<Value *>(val)->getDefiningInst(); - }; - using ReferenceTy = decltype(*(state.replacementMap.begin())); - auto getKey = [](ReferenceTy it) { return it.first; }; - auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); - - // Vectorize the root operations and everything reached by use-def chains - // except the terminators (store instructions) that need to be - // post-processed separately. - fail = vectorizeOperations(&state); - if (fail) { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations"); - continue; - } + // Form the root operationsthat have been set in the replacementMap. + // For now, these roots are the loads for which vector_transfer_read + // operations have been inserted. + auto getDefiningInst = [](const Value *val) { + return const_cast<Value *>(val)->getDefiningInst(); + }; + using ReferenceTy = decltype(*(state.replacementMap.begin())); + auto getKey = [](ReferenceTy it) { return it.first; }; + auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); + + // Vectorize the root operations and everything reached by use-def chains + // except the terminators (store instructions) that need to be + // post-processed separately. + fail = vectorizeOperations(&state); + if (fail) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations"); + return true; + } - // Finally, vectorize the terminators. If anything fails to vectorize, skip. - auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { - if (fail) { - return; - } - FuncBuilder b(inst); - auto *res = vectorizeOneOperationInst(&b, inst, &state); - if (res == nullptr) { - fail = true; - } - }; - apply(vectorizeOrFail, state.terminators); + // Finally, vectorize the terminators. If anything fails to vectorize, skip. + auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { if (fail) { - LLVM_DEBUG( - dbgs() << "\n[early-vect]+++++ failed to vectorize terminators"); - continue; - } else { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); + return; } - - // Finish this vectorization pattern. 
- state.finishVectorizationPattern(); + FuncBuilder b(inst); + auto *res = vectorizeOneOperationInst(&b, inst, &state); + if (res == nullptr) { + fail = true; + } + }; + apply(vectorizeOrFail, state.terminators); + if (fail) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminators"); + return true; + } else { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); } + + // Finish this vectorization pattern. + state.finishVectorizationPattern(); return false; } /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. PassResult Vectorize::runOnFunction(Function *f) { - for (auto pat : makePatterns()) { + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. + NestedPatternContext mlContext; + + for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n"); LLVM_DEBUG(f->print(dbgs())); unsigned patternDepth = pat.getDepth(); - auto matches = pat.match(f); + + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. for (auto m : matches) { @@ -1285,16 +1285,16 @@ PassResult Vectorize::runOnFunction(Function *f) { // TODO(ntv): depending on profitability, elect to reduce the vector size. strategy.vectorSizes.assign(clVirtualVectorSize.begin(), clVirtualVectorSize.end()); - auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy); + auto fail = analyzeProfitability(m.getMatchedChildren(), 1, patternDepth, + &strategy); if (fail) { continue; } - auto *loop = cast<ForInst>(m.first); + auto *loop = cast<ForInst>(m.getMatchedInstruction()); vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. - fail = vectorizeRootMatches(matches, &strategy); - assert(!fail && "top-level failure should not happen"); + fail = vectorizeRootMatch(m, &strategy); // TODO(ntv): some diagnostics. } } |
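
Finally, cleanup (2) from the commit message deserves a call-out: with the defensive assertions added to the `NestedPatternContext` constructor, creating a second context while one is live now fails fast (assuming asserts are enabled), which is what makes the per-run scoping inside `runOnFunction` in the updated passes safe. A purely illustrative misuse, not taken from this patch:

```cpp
// Illustrative only: this is the situation the new defensive assertions catch
// ("Only a single NestedPatternContext is supported").
void example() {
  NestedPatternContext outer; // installs the shared BumpPtrAllocator
  {
    // NestedPatternContext inner; // would trip the assertion in the ctor
  }
} // the allocator pointers are reset to nullptr when `outer` is destroyed
```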

