diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Analysis/NestedMatcher.h (renamed from mlir/include/mlir/Analysis/MLFunctionMatcher.h) | 131 | ||||
| -rw-r--r-- | mlir/include/mlir/EDSC/MLIREmitter.h | 2 | ||||
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Analysis/MLFunctionMatcher.cpp | 263 | ||||
| -rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 240 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ComposeAffineMaps.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 6 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 32 |
10 files changed, 334 insertions, 355 deletions
diff --git a/mlir/include/mlir/Analysis/MLFunctionMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 5de1f6d729b..c205d55488e 100644 --- a/mlir/include/mlir/Analysis/MLFunctionMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -1,4 +1,4 @@ -//===- MLFunctionMacher.h - Recursive matcher for MLFunction ----*- C++ -*-===// +//===- NestedMacher.h - Nested matcher for MLFunction -----------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // @@ -24,22 +24,22 @@ namespace mlir { -struct MLFunctionMatcherStorage; -struct MLFunctionMatchesStorage; +struct NestedPatternStorage; +struct NestedMatchStorage; class Instruction; -/// An MLFunctionMatcher is a recursive matcher that captures nested patterns in -/// an ML Function. It is used in conjunction with a scoped -/// MLFunctionMatcherContext that handles the memory allocations efficiently. +/// An NestedPattern captures nested patterns. It is used in conjunction with +/// a scoped NestedPatternContext which is an llvm::BumPtrAllocator that +/// handles memory allocations efficiently and avoids ownership issues. /// -/// In order to use MLFunctionMatchers creates a scoped context and uses -/// matchers. When the context goes out of scope, everything is freed. +/// In order to use NestedPatterns, first create a scoped context. When the +/// context goes out of scope, everything is freed. /// This design simplifies the API by avoiding references to the context and /// makes it clear that references to matchers must not escape. /// /// Example: /// { -/// MLFunctionMatcherContext context; +/// NestedPatternContext context; /// auto gemmLike = Doall(Doall(Red(LoadStores()))); /// auto matches = gemmLike.match(f); /// // do work on matches @@ -51,15 +51,17 @@ class Instruction; /// /// Implemented as a POD value-type with underlying storage pointer. /// The underlying storage lives in a scoped bumper allocator whose lifetime -/// is managed by an RAII MLFunctionMatcherContext. -/// This should be used by value everywhere. -struct MLFunctionMatches { - using EntryType = std::pair<Instruction *, MLFunctionMatches>; +/// is managed by an RAII NestedPatternContext. +/// This is used by value everywhere. +struct NestedMatch { + using EntryType = std::pair<Instruction *, NestedMatch>; using iterator = EntryType *; - MLFunctionMatches() : storage(nullptr) {} + static NestedMatch build(ArrayRef<NestedMatch::EntryType> elements = {}); + NestedMatch(const NestedMatch &) = default; + NestedMatch &operator=(const NestedMatch &) = default; - explicit operator bool() { return storage; } + explicit operator bool() { return !empty(); } iterator begin(); iterator end(); @@ -68,20 +70,25 @@ struct MLFunctionMatches { unsigned size() { return end() - begin(); } unsigned empty() { return size() == 0; } - /// Appends the pair <inst, children> to the current matches. - void append(Instruction *inst, MLFunctionMatches children); - private: - friend class MLFunctionMatcher; - friend class MLFunctionMatcherContext; + friend class NestedPattern; + friend class NestedPatternContext; + friend class NestedMatchStorage; - /// Underlying global bump allocator managed by an MLFunctionMatcherContext. + /// Underlying global bump allocator managed by a NestedPatternContext. static llvm::BumpPtrAllocator *&allocator(); - MLFunctionMatchesStorage *storage; + NestedMatch(NestedMatchStorage *storage) : storage(storage){}; + + /// Copy the specified array of elements into memory managed by our bump + /// pointer allocator. The elements are all PODs by constructions. + static NestedMatch copyInto(ArrayRef<NestedMatch::EntryType> elements); + + /// POD payload. + NestedMatchStorage *storage; }; -/// A MLFunctionMatcher is a special type of InstWalker that: +/// A NestedPattern is a special type of InstWalker that: /// 1. recursively matches a substructure in the tree; /// 2. uses a filter function to refine matches with extra semantic /// constraints (passed via a lambda of type FilterFunctionType); @@ -89,78 +96,76 @@ private: /// /// Implemented as a POD value-type with underlying storage pointer. /// The underlying storage lives in a scoped bumper allocator whose lifetime -/// is managed by an RAII MLFunctionMatcherContext. +/// is managed by an RAII NestedPatternContext. /// This should be used by value everywhere. using FilterFunctionType = std::function<bool(const Instruction &)>; static bool defaultFilterFunction(const Instruction &) { return true; }; -struct MLFunctionMatcher : public InstWalker<MLFunctionMatcher> { - MLFunctionMatcher(Instruction::Kind k, MLFunctionMatcher child, - FilterFunctionType filter = defaultFilterFunction); - MLFunctionMatcher(Instruction::Kind k, - MutableArrayRef<MLFunctionMatcher> children, - FilterFunctionType filter = defaultFilterFunction); +struct NestedPattern : public InstWalker<NestedPattern> { + NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested, + FilterFunctionType filter = defaultFilterFunction); + NestedPattern(const NestedPattern &) = default; + NestedPattern &operator=(const NestedPattern &) = default; /// Returns all the matches in `function`. - MLFunctionMatches match(Function *function); + NestedMatch match(Function *function); /// Returns all the matches nested under `instruction`. - MLFunctionMatches match(Instruction *instruction); + NestedMatch match(Instruction *instruction); unsigned getDepth(); private: - friend class MLFunctionMatcherContext; - friend InstWalker<MLFunctionMatcher>; + friend class NestedPatternContext; + friend InstWalker<NestedPattern>; + + /// Underlying global bump allocator managed by a NestedPatternContext. + static llvm::BumpPtrAllocator *&allocator(); Instruction::Kind getKind(); - MutableArrayRef<MLFunctionMatcher> getChildrenMLFunctionMatchers(); + ArrayRef<NestedPattern> getNestedPatterns(); FilterFunctionType getFilterFunction(); - MLFunctionMatcher forkMLFunctionMatcherAt(MLFunctionMatcher tmpl, - Instruction *inst); - void matchOne(Instruction *elem); void visitForInst(ForInst *forInst) { matchOne(forInst); } void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } - /// Underlying global bump allocator managed by an MLFunctionMatcherContext. - static llvm::BumpPtrAllocator *&allocator(); - - MLFunctionMatcherStorage *storage; + /// POD paylod. + /// Storage for the PatternMatcher. + NestedPatternStorage *storage; // By-value POD wrapper to underlying storage pointer. - MLFunctionMatches matches; + NestedMatch matches; }; /// RAII structure to transparently manage the bump allocator for -/// MLFunctionMatcher and MLFunctionMatches classes. -struct MLFunctionMatcherContext { - MLFunctionMatcherContext() { - MLFunctionMatcher::allocator() = &allocator; - MLFunctionMatches::allocator() = &allocator; +/// NestedPattern and NestedMatch classes. +struct NestedPatternContext { + NestedPatternContext() { + NestedPattern::allocator() = &allocator; + NestedMatch::allocator() = &allocator; } - ~MLFunctionMatcherContext() { - MLFunctionMatcher::allocator() = nullptr; - MLFunctionMatches::allocator() = nullptr; + ~NestedPatternContext() { + NestedPattern::allocator() = nullptr; + NestedMatch::allocator() = nullptr; } llvm::BumpPtrAllocator allocator; }; namespace matcher { -// Syntactic sugar MLFunctionMatcher builder functions. -MLFunctionMatcher Op(FilterFunctionType filter = defaultFilterFunction); -MLFunctionMatcher If(MLFunctionMatcher child); -MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child); -MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children = {}); -MLFunctionMatcher If(FilterFunctionType filter, - MutableArrayRef<MLFunctionMatcher> children = {}); -MLFunctionMatcher For(MLFunctionMatcher child); -MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child); -MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children = {}); -MLFunctionMatcher For(FilterFunctionType filter, - MutableArrayRef<MLFunctionMatcher> children = {}); +// Syntactic sugar NestedPattern builder functions. +NestedPattern Op(FilterFunctionType filter = defaultFilterFunction); +NestedPattern If(NestedPattern child); +NestedPattern If(FilterFunctionType filter, NestedPattern child); +NestedPattern If(ArrayRef<NestedPattern> nested = {}); +NestedPattern If(FilterFunctionType filter, + ArrayRef<NestedPattern> nested = {}); +NestedPattern For(NestedPattern child); +NestedPattern For(FilterFunctionType filter, NestedPattern child); +NestedPattern For(ArrayRef<NestedPattern> nested = {}); +NestedPattern For(FilterFunctionType filter, + ArrayRef<NestedPattern> nested = {}); bool isParallelLoop(const Instruction &inst); bool isReductionLoop(const Instruction &inst); diff --git a/mlir/include/mlir/EDSC/MLIREmitter.h b/mlir/include/mlir/EDSC/MLIREmitter.h index a696914eee2..fbd5a544a30 100644 --- a/mlir/include/mlir/EDSC/MLIREmitter.h +++ b/mlir/include/mlir/EDSC/MLIREmitter.h @@ -23,7 +23,7 @@ // generally designed to be automatically generated from various IR dialects in // the future. // The implementation is supported by a lightweight by-value abstraction on a -// scoped BumpAllocator with similarities to AffineExpr and MLFunctionMatcher. +// scoped BumpAllocator with similarities to AffineExpr and NestedPattern. // //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index b66b665c563..b154ebab105 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -23,7 +23,7 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" -#include "mlir/Analysis/MLFunctionMatcher.h" +#include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp deleted file mode 100644 index f2bbcd2a566..00000000000 --- a/mlir/lib/Analysis/MLFunctionMatcher.cpp +++ /dev/null @@ -1,263 +0,0 @@ -//===- MLFunctionMatcher.cpp - MLFunctionMatcher Impl ----------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/Analysis/MLFunctionMatcher.h" -#include "mlir/StandardOps/StandardOps.h" - -#include "llvm/Support/Allocator.h" - -namespace mlir { - -/// Underlying storage for MLFunctionMatches. -struct MLFunctionMatchesStorage { - MLFunctionMatchesStorage(MLFunctionMatches::EntryType e) : matches({e}) {} - - SmallVector<MLFunctionMatches::EntryType, 8> matches; -}; - -/// Underlying storage for MLFunctionMatcher. -struct MLFunctionMatcherStorage { - MLFunctionMatcherStorage(Instruction::Kind k, - MutableArrayRef<MLFunctionMatcher> c, - FilterFunctionType filter, Instruction *skip) - : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter), - skip(skip) {} - - Instruction::Kind kind; - SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers; - FilterFunctionType filter; - /// skip is needed so that we can implement match without switching on the - /// type of the Instruction. - /// The idea is that a MLFunctionMatcher first checks if it matches locally - /// and then recursively applies its children matchers to its elem->children. - /// Since we want to rely on the InstWalker impl rather than duplicate its - /// the logic, we allow an off-by-one traversal to account for the fact that - /// we write: - /// - /// void match(Instruction *elem) { - /// for (auto &c : getChildrenMLFunctionMatchers()) { - /// MLFunctionMatcher childMLFunctionMatcher(...); - /// ^~~~ Needs off-by-one skip. - /// - Instruction *skip; -}; - -} // end namespace mlir - -using namespace mlir; - -llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() { - static thread_local llvm::BumpPtrAllocator *allocator = nullptr; - return allocator; -} - -void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) { - if (!storage) { - storage = allocator()->Allocate<MLFunctionMatchesStorage>(); - new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children)); - } else { - storage->matches.push_back(std::make_pair(inst, children)); - } -} -MLFunctionMatches::iterator MLFunctionMatches::begin() { - return storage ? storage->matches.begin() : nullptr; -} -MLFunctionMatches::iterator MLFunctionMatches::end() { - return storage ? storage->matches.end() : nullptr; -} -MLFunctionMatches::EntryType &MLFunctionMatches::front() { - assert(storage && "null storage"); - return *storage->matches.begin(); -} -MLFunctionMatches::EntryType &MLFunctionMatches::back() { - assert(storage && "null storage"); - return *(storage->matches.begin() + size() - 1); -} -/// Return the combination of multiple MLFunctionMatches as a new object. -static MLFunctionMatches combine(ArrayRef<MLFunctionMatches> matches) { - MLFunctionMatches res; - for (auto s : matches) { - for (auto ss : s) { - res.append(ss.first, ss.second); - } - } - return res; -} - -/// Calls walk on `function`. -MLFunctionMatches MLFunctionMatcher::match(Function *function) { - assert(!matches && "MLFunctionMatcher already matched!"); - this->walkPostOrder(function); - return matches; -} - -/// Calls walk on `instruction`. -MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) { - assert(!matches && "MLFunctionMatcher already matched!"); - this->walkPostOrder(instruction); - return matches; -} - -unsigned MLFunctionMatcher::getDepth() { - auto children = getChildrenMLFunctionMatchers(); - if (children.empty()) { - return 1; - } - unsigned depth = 0; - for (auto &c : children) { - depth = std::max(depth, c.getDepth()); - } - return depth + 1; -} - -/// Matches a single instruction in the following way: -/// 1. checks the kind of instruction against the matcher, if different then -/// there is no match; -/// 2. calls the customizable filter function to refine the single instruction -/// match with extra semantic constraints; -/// 3. if all is good, recursivey matches the children patterns; -/// 4. if all children match then the single instruction matches too and is -/// appended to the list of matches; -/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will -/// want to traverse in post-order DFS to avoid invalidating iterators. -void MLFunctionMatcher::matchOne(Instruction *elem) { - if (storage->skip == elem) { - return; - } - // Structural filter - if (elem->getKind() != getKind()) { - return; - } - // Local custom filter function - if (!getFilterFunction()(*elem)) { - return; - } - SmallVector<MLFunctionMatches, 8> childrenMLFunctionMatches; - for (auto &c : getChildrenMLFunctionMatchers()) { - /// We create a new childMLFunctionMatcher here because a matcher holds its - /// results. So we concretely need multiple copies of a given matcher, one - /// for each matching result. - MLFunctionMatcher childMLFunctionMatcher = forkMLFunctionMatcherAt(c, elem); - childMLFunctionMatcher.walkPostOrder(elem); - if (!childMLFunctionMatcher.matches) { - return; - } - childrenMLFunctionMatches.push_back(childMLFunctionMatcher.matches); - } - matches.append(elem, combine(childrenMLFunctionMatches)); -} - -llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() { - static thread_local llvm::BumpPtrAllocator *allocator = nullptr; - return allocator; -} - -MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k, - MLFunctionMatcher child, - FilterFunctionType filter) - : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) { - // Initialize with placement new. - new (storage) - MLFunctionMatcherStorage(k, {child}, filter, nullptr /* skip */); -} - -MLFunctionMatcher::MLFunctionMatcher( - Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children, - FilterFunctionType filter) - : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) { - // Initialize with placement new. - new (storage) - MLFunctionMatcherStorage(k, children, filter, nullptr /* skip */); -} - -MLFunctionMatcher -MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl, - Instruction *inst) { - MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(), - tmpl.getFilterFunction()); - res.storage->skip = inst; - return res; -} - -Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; } - -MutableArrayRef<MLFunctionMatcher> -MLFunctionMatcher::getChildrenMLFunctionMatchers() { - return storage->childrenMLFunctionMatchers; -} - -FilterFunctionType MLFunctionMatcher::getFilterFunction() { - return storage->filter; -} - -namespace mlir { -namespace matcher { - -MLFunctionMatcher Op(FilterFunctionType filter) { - return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter); -} - -MLFunctionMatcher If(MLFunctionMatcher child) { - return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction); -} -MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) { - return MLFunctionMatcher(Instruction::Kind::If, child, filter); -} -MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Instruction::Kind::If, children, - defaultFilterFunction); -} -MLFunctionMatcher If(FilterFunctionType filter, - MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Instruction::Kind::If, children, filter); -} - -MLFunctionMatcher For(MLFunctionMatcher child) { - return MLFunctionMatcher(Instruction::Kind::For, child, - defaultFilterFunction); -} -MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) { - return MLFunctionMatcher(Instruction::Kind::For, child, filter); -} -MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Instruction::Kind::For, children, - defaultFilterFunction); -} -MLFunctionMatcher For(FilterFunctionType filter, - MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Instruction::Kind::For, children, filter); -} - -// TODO(ntv): parallel annotation on loops. -bool isParallelLoop(const Instruction &inst) { - const auto *loop = cast<ForInst>(&inst); - return (void *)loop || true; // loop->isParallel(); -}; - -// TODO(ntv): reduction annotation on loops. -bool isReductionLoop(const Instruction &inst) { - const auto *loop = cast<ForInst>(&inst); - return (void *)loop || true; // loop->isReduction(); -}; - -bool isLoadOrStore(const Instruction &inst) { - const auto *opInst = dyn_cast<OperationInst>(&inst); - return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()); -}; - -} // end namespace matcher -} // end namespace mlir diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp new file mode 100644 index 00000000000..4f32e9b22f4 --- /dev/null +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -0,0 +1,240 @@ +//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Analysis/NestedMatcher.h" +#include "mlir/StandardOps/StandardOps.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { + +/// Underlying storage for NestedMatch. +struct NestedMatchStorage { + MutableArrayRef<NestedMatch::EntryType> matches; +}; + +/// Underlying storage for NestedPattern. +struct NestedPatternStorage { + NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c, + FilterFunctionType filter, Instruction *skip) + : kind(k), nestedPatterns(c), filter(filter), skip(skip) {} + + Instruction::Kind kind; + ArrayRef<NestedPattern> nestedPatterns; + FilterFunctionType filter; + /// skip is needed so that we can implement match without switching on the + /// type of the Instruction. + /// The idea is that a NestedPattern first checks if it matches locally + /// and then recursively applies its nested matchers to its elem->nested. + /// Since we want to rely on the InstWalker impl rather than duplicate its + /// the logic, we allow an off-by-one traversal to account for the fact that + /// we write: + /// + /// void match(Instruction *elem) { + /// for (auto &c : getNestedPatterns()) { + /// NestedPattern childPattern(...); + /// ^~~~ Needs off-by-one skip. + /// + Instruction *skip; +}; + +} // end namespace mlir + +using namespace mlir; + +llvm::BumpPtrAllocator *&NestedMatch::allocator() { + static thread_local llvm::BumpPtrAllocator *allocator = nullptr; + return allocator; +} + +NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) { + auto *matches = + allocator()->Allocate<NestedMatch::EntryType>(elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), matches); + auto *storage = allocator()->Allocate<NestedMatchStorage>(); + new (storage) NestedMatchStorage(); + storage->matches = + MutableArrayRef<NestedMatch::EntryType>(matches, elements.size()); + auto *result = allocator()->Allocate<NestedMatch>(); + new (result) NestedMatch(storage); + return *result; +} + +NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); } +NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); } +NestedMatch::EntryType &NestedMatch::front() { + return *storage->matches.begin(); +} +NestedMatch::EntryType &NestedMatch::back() { + return *(storage->matches.begin() + size() - 1); +} + +/// Calls walk on `function`. +NestedMatch NestedPattern::match(Function *function) { + assert(!matches && "NestedPattern already matched!"); + this->walkPostOrder(function); + return matches; +} + +/// Calls walk on `instruction`. +NestedMatch NestedPattern::match(Instruction *instruction) { + assert(!matches && "NestedPattern already matched!"); + this->walkPostOrder(instruction); + return matches; +} + +unsigned NestedPattern::getDepth() { + auto nested = getNestedPatterns(); + if (nested.empty()) { + return 1; + } + unsigned depth = 0; + for (auto c : nested) { + depth = std::max(depth, c.getDepth()); + } + return depth + 1; +} + +/// Matches a single instruction in the following way: +/// 1. checks the kind of instruction against the matcher, if different then +/// there is no match; +/// 2. calls the customizable filter function to refine the single instruction +/// match with extra semantic constraints; +/// 3. if all is good, recursivey matches the nested patterns; +/// 4. if all nested match then the single instruction matches too and is +/// appended to the list of matches; +/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will +/// want to traverse in post-order DFS to avoid invalidating iterators. +void NestedPattern::matchOne(Instruction *elem) { + if (storage->skip == elem) { + return; + } + // Structural filter + if (elem->getKind() != getKind()) { + return; + } + // Local custom filter function + if (!getFilterFunction()(*elem)) { + return; + } + + SmallVector<NestedMatch::EntryType, 8> nestedEntries; + for (auto c : getNestedPatterns()) { + /// We create a new nestedPattern here because a matcher holds its + /// results. So we concretely need multiple copies of a given matcher, one + /// for each matching result. + NestedPattern nestedPattern(c); + // Skip elem in the walk immediately following. Without this we would + // essentially need to reimplement walkPostOrder here. + nestedPattern.storage->skip = elem; + nestedPattern.walkPostOrder(elem); + if (!nestedPattern.matches) { + return; + } + for (auto m : nestedPattern.matches) { + nestedEntries.push_back(m); + } + } + + SmallVector<NestedMatch::EntryType, 8> newEntries( + matches.storage->matches.begin(), matches.storage->matches.end()); + newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries))); + matches = NestedMatch::build(newEntries); +} + +llvm::BumpPtrAllocator *&NestedPattern::allocator() { + static thread_local llvm::BumpPtrAllocator *allocator = nullptr; + return allocator; +} + +NestedPattern::NestedPattern(Instruction::Kind k, + ArrayRef<NestedPattern> nested, + FilterFunctionType filter) + : storage(allocator()->Allocate<NestedPatternStorage>()), + matches(NestedMatch::build({})) { + auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size()); + std::uninitialized_copy(nested.begin(), nested.end(), newChildren); + // Initialize with placement new. + new (storage) NestedPatternStorage( + k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter, + nullptr /* skip */); +} + +Instruction::Kind NestedPattern::getKind() { return storage->kind; } + +ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() { + return storage->nestedPatterns; +} + +FilterFunctionType NestedPattern::getFilterFunction() { + return storage->filter; +} + +namespace mlir { +namespace matcher { + +NestedPattern Op(FilterFunctionType filter) { + return NestedPattern(Instruction::Kind::OperationInst, {}, filter); +} + +NestedPattern If(NestedPattern child) { + return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); +} +NestedPattern If(FilterFunctionType filter, NestedPattern child) { + return NestedPattern(Instruction::Kind::If, child, filter); +} +NestedPattern If(ArrayRef<NestedPattern> nested) { + return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); +} +NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { + return NestedPattern(Instruction::Kind::If, nested, filter); +} + +NestedPattern For(NestedPattern child) { + return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); +} +NestedPattern For(FilterFunctionType filter, NestedPattern child) { + return NestedPattern(Instruction::Kind::For, child, filter); +} +NestedPattern For(ArrayRef<NestedPattern> nested) { + return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); +} +NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { + return NestedPattern(Instruction::Kind::For, nested, filter); +} + +// TODO(ntv): parallel annotation on loops. +bool isParallelLoop(const Instruction &inst) { + const auto *loop = cast<ForInst>(&inst); + return (void *)loop || true; // loop->isParallel(); +}; + +// TODO(ntv): reduction annotation on loops. +bool isReductionLoop(const Instruction &inst) { + const auto *loop = cast<ForInst>(&inst); + return (void *)loop || true; // loop->isReduction(); +}; + +bool isLoadOrStore(const Instruction &inst) { + const auto *opInst = dyn_cast<OperationInst>(&inst); + return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()); +}; + +} // end namespace matcher +} // end namespace mlir diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 3dc81a0d793..4752928d062 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -22,7 +22,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/MLFunctionMatcher.h" +#include "mlir/Analysis/NestedMatcher.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -49,7 +49,7 @@ struct ComposeAffineMaps : public FunctionPass { PassResult runOnFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - MLFunctionMatcherContext MLContext; + NestedPatternContext MLContext; static char passID; }; @@ -74,8 +74,7 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) { auto apps = pattern.match(f); for (auto m : apps) { auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>(); - SmallVector<Value *, 8> operands(app->getOperands().begin(), - app->getOperands().end()); + SmallVector<Value *, 8> operands(app->getOperands()); FuncBuilder b(m.first); auto newApp = makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 19208d4c268..bb080949293 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -22,7 +22,7 @@ #include <type_traits> #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/MLFunctionMatcher.h" +#include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/EDSC/MLIREmitter.h" diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 8be97afeebd..7dd3cecdfed 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -23,7 +23,7 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/Analysis/MLFunctionMatcher.h" +#include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" @@ -200,7 +200,7 @@ struct MaterializeVectorsPass : public FunctionPass { PassResult runOnFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - MLFunctionMatcherContext mlContext; + NestedPatternContext mlContext; static char passID; }; diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index c08ffd4cd7d..0a199f008d6 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -20,7 +20,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/MLFunctionMatcher.h" +#include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/Builders.h" @@ -95,7 +95,7 @@ struct VectorizerTestPass : public FunctionPass { void testNormalizeMaps(Function *f); // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - MLFunctionMatcherContext MLContext; + NestedPatternContext MLContext; static char passID; }; @@ -153,7 +153,7 @@ static std::string toString(Instruction *inst) { return res; } -static MLFunctionMatches matchTestSlicingOps(Function *f) { +static NestedMatch matchTestSlicingOps(Function *f) { // Just use a custom op name for this test, it makes life easier. constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 29a97991d5e..e9b37fcc04c 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -21,7 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/Analysis/MLFunctionMatcher.h" +#include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" @@ -567,14 +567,14 @@ static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension); // Build a bunch of predetermined patterns that will be traversed in order. -// Due to the recursive nature of MLFunctionMatchers, this captures +// Due to the recursive nature of NestedPatterns, this captures // arbitrarily nested pairs of loops at any position in the tree. /// Note that this currently only matches 2 nested loops and will be extended. // TODO(ntv): support 3-D loop patterns with a common reduction loop that can // be matched to GEMMs. -static std::vector<MLFunctionMatcher> defaultPatterns() { +static std::vector<NestedPattern> defaultPatterns() { using matcher::For; - return std::vector<MLFunctionMatcher>{ + return std::vector<NestedPattern>{ // 3-D patterns For(isVectorizableLoopPtrFactory(2), For(isVectorizableLoopPtrFactory(1), @@ -627,7 +627,7 @@ static std::vector<MLFunctionMatcher> defaultPatterns() { /// Up to 3-D patterns are supported. /// If the command line argument requests a pattern of higher order, returns an /// empty pattern list which will conservatively result in no vectorization. -static std::vector<MLFunctionMatcher> makePatterns() { +static std::vector<NestedPattern> makePatterns() { using matcher::For; if (clFastestVaryingPattern.empty()) { return defaultPatterns(); @@ -644,7 +644,7 @@ static std::vector<MLFunctionMatcher> makePatterns() { For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1]), For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[2]))))}; default: - return std::vector<MLFunctionMatcher>(); + return std::vector<NestedPattern>(); } } @@ -656,7 +656,7 @@ struct Vectorize : public FunctionPass { PassResult runOnFunction(Function *f) override; // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - MLFunctionMatcherContext MLContext; + NestedPatternContext MLContext; static char passID; }; @@ -703,8 +703,8 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, /// 3. account for impact of vectorization on maximal loop fusion. /// Then we can quantify the above to build a cost model and search over /// strategies. -static bool analyzeProfitability(MLFunctionMatches matches, - unsigned depthInPattern, unsigned patternDepth, +static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern, + unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { auto *loop = cast<ForInst>(m.first); @@ -890,7 +890,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, return false; } -/// Returns a FilterFunctionType that can be used in MLFunctionMatcher to +/// Returns a FilterFunctionType that can be used in NestedPattern to /// match a loop whose underlying load/store accesses are all varying along the /// `fastestVaryingMemRefDimension`. /// TODO(ntv): In the future, allow more interesting mixed layout permutation @@ -906,16 +906,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { } /// Forward-declaration. -static bool vectorizeNonRoot(MLFunctionMatches matches, - VectorizationState *state); +static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state); /// Apply vectorization of `loop` according to `state`. This is only triggered /// if all vectorizations in `childrenMatches` have already succeeded /// recursively in DFS post-order. -static bool doVectorize(MLFunctionMatches::EntryType oneMatch, +static bool doVectorize(NestedMatch::EntryType oneMatch, VectorizationState *state) { ForInst *loop = cast<ForInst>(oneMatch.first); - MLFunctionMatches childrenMatches = oneMatch.second; + NestedMatch childrenMatches = oneMatch.second; // 1. DFS postorder recursion, if any of my children fails, I fail too. auto fail = vectorizeNonRoot(childrenMatches, state); @@ -949,8 +948,7 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch, /// Non-root pattern iterates over the matches at this level, calls doVectorize /// and exits early if anything below fails. -static bool vectorizeNonRoot(MLFunctionMatches matches, - VectorizationState *state) { +static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) { for (auto m : matches) { auto fail = doVectorize(m, state); if (fail) { @@ -1186,7 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) { /// The root match thus needs to maintain a clone for handling failure. /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. -static bool vectorizeRootMatches(MLFunctionMatches matches, +static bool vectorizeRootMatches(NestedMatch matches, VectorizationStrategy *strategy) { for (auto m : matches) { auto *loop = cast<ForInst>(m.first); |

