summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/NestedMatcher.h (renamed from mlir/include/mlir/Analysis/MLFunctionMatcher.h)131
-rw-r--r--mlir/include/mlir/EDSC/MLIREmitter.h2
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/MLFunctionMatcher.cpp263
-rw-r--r--mlir/lib/Analysis/NestedMatcher.cpp240
-rw-r--r--mlir/lib/Transforms/ComposeAffineMaps.cpp7
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp2
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp4
-rw-r--r--mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp6
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp32
10 files changed, 334 insertions, 355 deletions
diff --git a/mlir/include/mlir/Analysis/MLFunctionMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h
index 5de1f6d729b..c205d55488e 100644
--- a/mlir/include/mlir/Analysis/MLFunctionMatcher.h
+++ b/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -1,4 +1,4 @@
-//===- MLFunctionMacher.h - Recursive matcher for MLFunction ----*- C++ -*-===//
+//===- NestedMacher.h - Nested matcher for MLFunction -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -24,22 +24,22 @@
namespace mlir {
-struct MLFunctionMatcherStorage;
-struct MLFunctionMatchesStorage;
+struct NestedPatternStorage;
+struct NestedMatchStorage;
class Instruction;
-/// An MLFunctionMatcher is a recursive matcher that captures nested patterns in
-/// an ML Function. It is used in conjunction with a scoped
-/// MLFunctionMatcherContext that handles the memory allocations efficiently.
+/// An NestedPattern captures nested patterns. It is used in conjunction with
+/// a scoped NestedPatternContext which is an llvm::BumPtrAllocator that
+/// handles memory allocations efficiently and avoids ownership issues.
///
-/// In order to use MLFunctionMatchers creates a scoped context and uses
-/// matchers. When the context goes out of scope, everything is freed.
+/// In order to use NestedPatterns, first create a scoped context. When the
+/// context goes out of scope, everything is freed.
/// This design simplifies the API by avoiding references to the context and
/// makes it clear that references to matchers must not escape.
///
/// Example:
/// {
-/// MLFunctionMatcherContext context;
+/// NestedPatternContext context;
/// auto gemmLike = Doall(Doall(Red(LoadStores())));
/// auto matches = gemmLike.match(f);
/// // do work on matches
@@ -51,15 +51,17 @@ class Instruction;
///
/// Implemented as a POD value-type with underlying storage pointer.
/// The underlying storage lives in a scoped bumper allocator whose lifetime
-/// is managed by an RAII MLFunctionMatcherContext.
-/// This should be used by value everywhere.
-struct MLFunctionMatches {
- using EntryType = std::pair<Instruction *, MLFunctionMatches>;
+/// is managed by an RAII NestedPatternContext.
+/// This is used by value everywhere.
+struct NestedMatch {
+ using EntryType = std::pair<Instruction *, NestedMatch>;
using iterator = EntryType *;
- MLFunctionMatches() : storage(nullptr) {}
+ static NestedMatch build(ArrayRef<NestedMatch::EntryType> elements = {});
+ NestedMatch(const NestedMatch &) = default;
+ NestedMatch &operator=(const NestedMatch &) = default;
- explicit operator bool() { return storage; }
+ explicit operator bool() { return !empty(); }
iterator begin();
iterator end();
@@ -68,20 +70,25 @@ struct MLFunctionMatches {
unsigned size() { return end() - begin(); }
unsigned empty() { return size() == 0; }
- /// Appends the pair <inst, children> to the current matches.
- void append(Instruction *inst, MLFunctionMatches children);
-
private:
- friend class MLFunctionMatcher;
- friend class MLFunctionMatcherContext;
+ friend class NestedPattern;
+ friend class NestedPatternContext;
+ friend class NestedMatchStorage;
- /// Underlying global bump allocator managed by an MLFunctionMatcherContext.
+ /// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
- MLFunctionMatchesStorage *storage;
+ NestedMatch(NestedMatchStorage *storage) : storage(storage){};
+
+ /// Copy the specified array of elements into memory managed by our bump
+ /// pointer allocator. The elements are all PODs by constructions.
+ static NestedMatch copyInto(ArrayRef<NestedMatch::EntryType> elements);
+
+ /// POD payload.
+ NestedMatchStorage *storage;
};
-/// A MLFunctionMatcher is a special type of InstWalker that:
+/// A NestedPattern is a special type of InstWalker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
@@ -89,78 +96,76 @@ private:
///
/// Implemented as a POD value-type with underlying storage pointer.
/// The underlying storage lives in a scoped bumper allocator whose lifetime
-/// is managed by an RAII MLFunctionMatcherContext.
+/// is managed by an RAII NestedPatternContext.
/// This should be used by value everywhere.
using FilterFunctionType = std::function<bool(const Instruction &)>;
static bool defaultFilterFunction(const Instruction &) { return true; };
-struct MLFunctionMatcher : public InstWalker<MLFunctionMatcher> {
- MLFunctionMatcher(Instruction::Kind k, MLFunctionMatcher child,
- FilterFunctionType filter = defaultFilterFunction);
- MLFunctionMatcher(Instruction::Kind k,
- MutableArrayRef<MLFunctionMatcher> children,
- FilterFunctionType filter = defaultFilterFunction);
+struct NestedPattern : public InstWalker<NestedPattern> {
+ NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested,
+ FilterFunctionType filter = defaultFilterFunction);
+ NestedPattern(const NestedPattern &) = default;
+ NestedPattern &operator=(const NestedPattern &) = default;
/// Returns all the matches in `function`.
- MLFunctionMatches match(Function *function);
+ NestedMatch match(Function *function);
/// Returns all the matches nested under `instruction`.
- MLFunctionMatches match(Instruction *instruction);
+ NestedMatch match(Instruction *instruction);
unsigned getDepth();
private:
- friend class MLFunctionMatcherContext;
- friend InstWalker<MLFunctionMatcher>;
+ friend class NestedPatternContext;
+ friend InstWalker<NestedPattern>;
+
+ /// Underlying global bump allocator managed by a NestedPatternContext.
+ static llvm::BumpPtrAllocator *&allocator();
Instruction::Kind getKind();
- MutableArrayRef<MLFunctionMatcher> getChildrenMLFunctionMatchers();
+ ArrayRef<NestedPattern> getNestedPatterns();
FilterFunctionType getFilterFunction();
- MLFunctionMatcher forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
- Instruction *inst);
-
void matchOne(Instruction *elem);
void visitForInst(ForInst *forInst) { matchOne(forInst); }
void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
- /// Underlying global bump allocator managed by an MLFunctionMatcherContext.
- static llvm::BumpPtrAllocator *&allocator();
-
- MLFunctionMatcherStorage *storage;
+ /// POD paylod.
+ /// Storage for the PatternMatcher.
+ NestedPatternStorage *storage;
// By-value POD wrapper to underlying storage pointer.
- MLFunctionMatches matches;
+ NestedMatch matches;
};
/// RAII structure to transparently manage the bump allocator for
-/// MLFunctionMatcher and MLFunctionMatches classes.
-struct MLFunctionMatcherContext {
- MLFunctionMatcherContext() {
- MLFunctionMatcher::allocator() = &allocator;
- MLFunctionMatches::allocator() = &allocator;
+/// NestedPattern and NestedMatch classes.
+struct NestedPatternContext {
+ NestedPatternContext() {
+ NestedPattern::allocator() = &allocator;
+ NestedMatch::allocator() = &allocator;
}
- ~MLFunctionMatcherContext() {
- MLFunctionMatcher::allocator() = nullptr;
- MLFunctionMatches::allocator() = nullptr;
+ ~NestedPatternContext() {
+ NestedPattern::allocator() = nullptr;
+ NestedMatch::allocator() = nullptr;
}
llvm::BumpPtrAllocator allocator;
};
namespace matcher {
-// Syntactic sugar MLFunctionMatcher builder functions.
-MLFunctionMatcher Op(FilterFunctionType filter = defaultFilterFunction);
-MLFunctionMatcher If(MLFunctionMatcher child);
-MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child);
-MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children = {});
-MLFunctionMatcher If(FilterFunctionType filter,
- MutableArrayRef<MLFunctionMatcher> children = {});
-MLFunctionMatcher For(MLFunctionMatcher child);
-MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child);
-MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children = {});
-MLFunctionMatcher For(FilterFunctionType filter,
- MutableArrayRef<MLFunctionMatcher> children = {});
+// Syntactic sugar NestedPattern builder functions.
+NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);
+NestedPattern If(NestedPattern child);
+NestedPattern If(FilterFunctionType filter, NestedPattern child);
+NestedPattern If(ArrayRef<NestedPattern> nested = {});
+NestedPattern If(FilterFunctionType filter,
+ ArrayRef<NestedPattern> nested = {});
+NestedPattern For(NestedPattern child);
+NestedPattern For(FilterFunctionType filter, NestedPattern child);
+NestedPattern For(ArrayRef<NestedPattern> nested = {});
+NestedPattern For(FilterFunctionType filter,
+ ArrayRef<NestedPattern> nested = {});
bool isParallelLoop(const Instruction &inst);
bool isReductionLoop(const Instruction &inst);
diff --git a/mlir/include/mlir/EDSC/MLIREmitter.h b/mlir/include/mlir/EDSC/MLIREmitter.h
index a696914eee2..fbd5a544a30 100644
--- a/mlir/include/mlir/EDSC/MLIREmitter.h
+++ b/mlir/include/mlir/EDSC/MLIREmitter.h
@@ -23,7 +23,7 @@
// generally designed to be automatically generated from various IR dialects in
// the future.
// The implementation is supported by a lightweight by-value abstraction on a
-// scoped BumpAllocator with similarities to AffineExpr and MLFunctionMatcher.
+// scoped BumpAllocator with similarities to AffineExpr and NestedPattern.
//
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index b66b665c563..b154ebab105 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -23,7 +23,7 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp
deleted file mode 100644
index f2bbcd2a566..00000000000
--- a/mlir/lib/Analysis/MLFunctionMatcher.cpp
+++ /dev/null
@@ -1,263 +0,0 @@
-//===- MLFunctionMatcher.cpp - MLFunctionMatcher Impl ----------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/Analysis/MLFunctionMatcher.h"
-#include "mlir/StandardOps/StandardOps.h"
-
-#include "llvm/Support/Allocator.h"
-
-namespace mlir {
-
-/// Underlying storage for MLFunctionMatches.
-struct MLFunctionMatchesStorage {
- MLFunctionMatchesStorage(MLFunctionMatches::EntryType e) : matches({e}) {}
-
- SmallVector<MLFunctionMatches::EntryType, 8> matches;
-};
-
-/// Underlying storage for MLFunctionMatcher.
-struct MLFunctionMatcherStorage {
- MLFunctionMatcherStorage(Instruction::Kind k,
- MutableArrayRef<MLFunctionMatcher> c,
- FilterFunctionType filter, Instruction *skip)
- : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter),
- skip(skip) {}
-
- Instruction::Kind kind;
- SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
- FilterFunctionType filter;
- /// skip is needed so that we can implement match without switching on the
- /// type of the Instruction.
- /// The idea is that a MLFunctionMatcher first checks if it matches locally
- /// and then recursively applies its children matchers to its elem->children.
- /// Since we want to rely on the InstWalker impl rather than duplicate its
- /// the logic, we allow an off-by-one traversal to account for the fact that
- /// we write:
- ///
- /// void match(Instruction *elem) {
- /// for (auto &c : getChildrenMLFunctionMatchers()) {
- /// MLFunctionMatcher childMLFunctionMatcher(...);
- /// ^~~~ Needs off-by-one skip.
- ///
- Instruction *skip;
-};
-
-} // end namespace mlir
-
-using namespace mlir;
-
-llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
- return allocator;
-}
-
-void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) {
- if (!storage) {
- storage = allocator()->Allocate<MLFunctionMatchesStorage>();
- new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children));
- } else {
- storage->matches.push_back(std::make_pair(inst, children));
- }
-}
-MLFunctionMatches::iterator MLFunctionMatches::begin() {
- return storage ? storage->matches.begin() : nullptr;
-}
-MLFunctionMatches::iterator MLFunctionMatches::end() {
- return storage ? storage->matches.end() : nullptr;
-}
-MLFunctionMatches::EntryType &MLFunctionMatches::front() {
- assert(storage && "null storage");
- return *storage->matches.begin();
-}
-MLFunctionMatches::EntryType &MLFunctionMatches::back() {
- assert(storage && "null storage");
- return *(storage->matches.begin() + size() - 1);
-}
-/// Return the combination of multiple MLFunctionMatches as a new object.
-static MLFunctionMatches combine(ArrayRef<MLFunctionMatches> matches) {
- MLFunctionMatches res;
- for (auto s : matches) {
- for (auto ss : s) {
- res.append(ss.first, ss.second);
- }
- }
- return res;
-}
-
-/// Calls walk on `function`.
-MLFunctionMatches MLFunctionMatcher::match(Function *function) {
- assert(!matches && "MLFunctionMatcher already matched!");
- this->walkPostOrder(function);
- return matches;
-}
-
-/// Calls walk on `instruction`.
-MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) {
- assert(!matches && "MLFunctionMatcher already matched!");
- this->walkPostOrder(instruction);
- return matches;
-}
-
-unsigned MLFunctionMatcher::getDepth() {
- auto children = getChildrenMLFunctionMatchers();
- if (children.empty()) {
- return 1;
- }
- unsigned depth = 0;
- for (auto &c : children) {
- depth = std::max(depth, c.getDepth());
- }
- return depth + 1;
-}
-
-/// Matches a single instruction in the following way:
-/// 1. checks the kind of instruction against the matcher, if different then
-/// there is no match;
-/// 2. calls the customizable filter function to refine the single instruction
-/// match with extra semantic constraints;
-/// 3. if all is good, recursivey matches the children patterns;
-/// 4. if all children match then the single instruction matches too and is
-/// appended to the list of matches;
-/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
-/// want to traverse in post-order DFS to avoid invalidating iterators.
-void MLFunctionMatcher::matchOne(Instruction *elem) {
- if (storage->skip == elem) {
- return;
- }
- // Structural filter
- if (elem->getKind() != getKind()) {
- return;
- }
- // Local custom filter function
- if (!getFilterFunction()(*elem)) {
- return;
- }
- SmallVector<MLFunctionMatches, 8> childrenMLFunctionMatches;
- for (auto &c : getChildrenMLFunctionMatchers()) {
- /// We create a new childMLFunctionMatcher here because a matcher holds its
- /// results. So we concretely need multiple copies of a given matcher, one
- /// for each matching result.
- MLFunctionMatcher childMLFunctionMatcher = forkMLFunctionMatcherAt(c, elem);
- childMLFunctionMatcher.walkPostOrder(elem);
- if (!childMLFunctionMatcher.matches) {
- return;
- }
- childrenMLFunctionMatches.push_back(childMLFunctionMatcher.matches);
- }
- matches.append(elem, combine(childrenMLFunctionMatches));
-}
-
-llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
- return allocator;
-}
-
-MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k,
- MLFunctionMatcher child,
- FilterFunctionType filter)
- : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
- // Initialize with placement new.
- new (storage)
- MLFunctionMatcherStorage(k, {child}, filter, nullptr /* skip */);
-}
-
-MLFunctionMatcher::MLFunctionMatcher(
- Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children,
- FilterFunctionType filter)
- : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
- // Initialize with placement new.
- new (storage)
- MLFunctionMatcherStorage(k, children, filter, nullptr /* skip */);
-}
-
-MLFunctionMatcher
-MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
- Instruction *inst) {
- MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
- tmpl.getFilterFunction());
- res.storage->skip = inst;
- return res;
-}
-
-Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; }
-
-MutableArrayRef<MLFunctionMatcher>
-MLFunctionMatcher::getChildrenMLFunctionMatchers() {
- return storage->childrenMLFunctionMatchers;
-}
-
-FilterFunctionType MLFunctionMatcher::getFilterFunction() {
- return storage->filter;
-}
-
-namespace mlir {
-namespace matcher {
-
-MLFunctionMatcher Op(FilterFunctionType filter) {
- return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter);
-}
-
-MLFunctionMatcher If(MLFunctionMatcher child) {
- return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction);
-}
-MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
- return MLFunctionMatcher(Instruction::Kind::If, child, filter);
-}
-MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Instruction::Kind::If, children,
- defaultFilterFunction);
-}
-MLFunctionMatcher If(FilterFunctionType filter,
- MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Instruction::Kind::If, children, filter);
-}
-
-MLFunctionMatcher For(MLFunctionMatcher child) {
- return MLFunctionMatcher(Instruction::Kind::For, child,
- defaultFilterFunction);
-}
-MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
- return MLFunctionMatcher(Instruction::Kind::For, child, filter);
-}
-MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Instruction::Kind::For, children,
- defaultFilterFunction);
-}
-MLFunctionMatcher For(FilterFunctionType filter,
- MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Instruction::Kind::For, children, filter);
-}
-
-// TODO(ntv): parallel annotation on loops.
-bool isParallelLoop(const Instruction &inst) {
- const auto *loop = cast<ForInst>(&inst);
- return (void *)loop || true; // loop->isParallel();
-};
-
-// TODO(ntv): reduction annotation on loops.
-bool isReductionLoop(const Instruction &inst) {
- const auto *loop = cast<ForInst>(&inst);
- return (void *)loop || true; // loop->isReduction();
-};
-
-bool isLoadOrStore(const Instruction &inst) {
- const auto *opInst = dyn_cast<OperationInst>(&inst);
- return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
-};
-
-} // end namespace matcher
-} // end namespace mlir
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
new file mode 100644
index 00000000000..4f32e9b22f4
--- /dev/null
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -0,0 +1,240 @@
+//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/StandardOps/StandardOps.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+/// Underlying storage for NestedMatch.
+struct NestedMatchStorage {
+ MutableArrayRef<NestedMatch::EntryType> matches;
+};
+
+/// Underlying storage for NestedPattern.
+struct NestedPatternStorage {
+ NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c,
+ FilterFunctionType filter, Instruction *skip)
+ : kind(k), nestedPatterns(c), filter(filter), skip(skip) {}
+
+ Instruction::Kind kind;
+ ArrayRef<NestedPattern> nestedPatterns;
+ FilterFunctionType filter;
+ /// skip is needed so that we can implement match without switching on the
+ /// type of the Instruction.
+ /// The idea is that a NestedPattern first checks if it matches locally
+ /// and then recursively applies its nested matchers to its elem->nested.
+ /// Since we want to rely on the InstWalker impl rather than duplicate its
+ /// the logic, we allow an off-by-one traversal to account for the fact that
+ /// we write:
+ ///
+ /// void match(Instruction *elem) {
+ /// for (auto &c : getNestedPatterns()) {
+ /// NestedPattern childPattern(...);
+ /// ^~~~ Needs off-by-one skip.
+ ///
+ Instruction *skip;
+};
+
+} // end namespace mlir
+
+using namespace mlir;
+
+llvm::BumpPtrAllocator *&NestedMatch::allocator() {
+ static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
+}
+
+NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) {
+ auto *matches =
+ allocator()->Allocate<NestedMatch::EntryType>(elements.size());
+ std::uninitialized_copy(elements.begin(), elements.end(), matches);
+ auto *storage = allocator()->Allocate<NestedMatchStorage>();
+ new (storage) NestedMatchStorage();
+ storage->matches =
+ MutableArrayRef<NestedMatch::EntryType>(matches, elements.size());
+ auto *result = allocator()->Allocate<NestedMatch>();
+ new (result) NestedMatch(storage);
+ return *result;
+}
+
+NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); }
+NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); }
+NestedMatch::EntryType &NestedMatch::front() {
+ return *storage->matches.begin();
+}
+NestedMatch::EntryType &NestedMatch::back() {
+ return *(storage->matches.begin() + size() - 1);
+}
+
+/// Calls walk on `function`.
+NestedMatch NestedPattern::match(Function *function) {
+ assert(!matches && "NestedPattern already matched!");
+ this->walkPostOrder(function);
+ return matches;
+}
+
+/// Calls walk on `instruction`.
+NestedMatch NestedPattern::match(Instruction *instruction) {
+ assert(!matches && "NestedPattern already matched!");
+ this->walkPostOrder(instruction);
+ return matches;
+}
+
+unsigned NestedPattern::getDepth() {
+ auto nested = getNestedPatterns();
+ if (nested.empty()) {
+ return 1;
+ }
+ unsigned depth = 0;
+ for (auto c : nested) {
+ depth = std::max(depth, c.getDepth());
+ }
+ return depth + 1;
+}
+
+/// Matches a single instruction in the following way:
+/// 1. checks the kind of instruction against the matcher, if different then
+/// there is no match;
+/// 2. calls the customizable filter function to refine the single instruction
+/// match with extra semantic constraints;
+/// 3. if all is good, recursivey matches the nested patterns;
+/// 4. if all nested match then the single instruction matches too and is
+/// appended to the list of matches;
+/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
+/// want to traverse in post-order DFS to avoid invalidating iterators.
+void NestedPattern::matchOne(Instruction *elem) {
+ if (storage->skip == elem) {
+ return;
+ }
+ // Structural filter
+ if (elem->getKind() != getKind()) {
+ return;
+ }
+ // Local custom filter function
+ if (!getFilterFunction()(*elem)) {
+ return;
+ }
+
+ SmallVector<NestedMatch::EntryType, 8> nestedEntries;
+ for (auto c : getNestedPatterns()) {
+ /// We create a new nestedPattern here because a matcher holds its
+ /// results. So we concretely need multiple copies of a given matcher, one
+ /// for each matching result.
+ NestedPattern nestedPattern(c);
+ // Skip elem in the walk immediately following. Without this we would
+ // essentially need to reimplement walkPostOrder here.
+ nestedPattern.storage->skip = elem;
+ nestedPattern.walkPostOrder(elem);
+ if (!nestedPattern.matches) {
+ return;
+ }
+ for (auto m : nestedPattern.matches) {
+ nestedEntries.push_back(m);
+ }
+ }
+
+ SmallVector<NestedMatch::EntryType, 8> newEntries(
+ matches.storage->matches.begin(), matches.storage->matches.end());
+ newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries)));
+ matches = NestedMatch::build(newEntries);
+}
+
+llvm::BumpPtrAllocator *&NestedPattern::allocator() {
+ static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
+}
+
+NestedPattern::NestedPattern(Instruction::Kind k,
+ ArrayRef<NestedPattern> nested,
+ FilterFunctionType filter)
+ : storage(allocator()->Allocate<NestedPatternStorage>()),
+ matches(NestedMatch::build({})) {
+ auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size());
+ std::uninitialized_copy(nested.begin(), nested.end(), newChildren);
+ // Initialize with placement new.
+ new (storage) NestedPatternStorage(
+ k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter,
+ nullptr /* skip */);
+}
+
+Instruction::Kind NestedPattern::getKind() { return storage->kind; }
+
+ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() {
+ return storage->nestedPatterns;
+}
+
+FilterFunctionType NestedPattern::getFilterFunction() {
+ return storage->filter;
+}
+
+namespace mlir {
+namespace matcher {
+
+NestedPattern Op(FilterFunctionType filter) {
+ return NestedPattern(Instruction::Kind::OperationInst, {}, filter);
+}
+
+NestedPattern If(NestedPattern child) {
+ return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
+}
+NestedPattern If(FilterFunctionType filter, NestedPattern child) {
+ return NestedPattern(Instruction::Kind::If, child, filter);
+}
+NestedPattern If(ArrayRef<NestedPattern> nested) {
+ return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
+}
+NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
+ return NestedPattern(Instruction::Kind::If, nested, filter);
+}
+
+NestedPattern For(NestedPattern child) {
+ return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction);
+}
+NestedPattern For(FilterFunctionType filter, NestedPattern child) {
+ return NestedPattern(Instruction::Kind::For, child, filter);
+}
+NestedPattern For(ArrayRef<NestedPattern> nested) {
+ return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction);
+}
+NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
+ return NestedPattern(Instruction::Kind::For, nested, filter);
+}
+
+// TODO(ntv): parallel annotation on loops.
+bool isParallelLoop(const Instruction &inst) {
+ const auto *loop = cast<ForInst>(&inst);
+ return (void *)loop || true; // loop->isParallel();
+};
+
+// TODO(ntv): reduction annotation on loops.
+bool isReductionLoop(const Instruction &inst) {
+ const auto *loop = cast<ForInst>(&inst);
+ return (void *)loop || true; // loop->isReduction();
+};
+
+bool isLoadOrStore(const Instruction &inst) {
+ const auto *opInst = dyn_cast<OperationInst>(&inst);
+ return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
+};
+
+} // end namespace matcher
+} // end namespace mlir
diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp
index 3dc81a0d793..4752928d062 100644
--- a/mlir/lib/Transforms/ComposeAffineMaps.cpp
+++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp
@@ -22,7 +22,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -49,7 +49,7 @@ struct ComposeAffineMaps : public FunctionPass {
PassResult runOnFunction(Function *f) override;
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- MLFunctionMatcherContext MLContext;
+ NestedPatternContext MLContext;
static char passID;
};
@@ -74,8 +74,7 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) {
auto apps = pattern.match(f);
for (auto m : apps) {
auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>();
- SmallVector<Value *, 8> operands(app->getOperands().begin(),
- app->getOperands().end());
+ SmallVector<Value *, 8> operands(app->getOperands());
FuncBuilder b(m.first);
auto newApp = makeComposedAffineApply(&b, app->getLoc(),
app->getAffineMap(), operands);
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index 19208d4c268..bb080949293 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -22,7 +22,7 @@
#include <type_traits>
#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/EDSC/MLIREmitter.h"
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 8be97afeebd..7dd3cecdfed 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -23,7 +23,7 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/LoopAnalysis.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Analysis/VectorAnalysis.h"
@@ -200,7 +200,7 @@ struct MaterializeVectorsPass : public FunctionPass {
PassResult runOnFunction(Function *f) override;
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- MLFunctionMatcherContext mlContext;
+ NestedPatternContext mlContext;
static char passID;
};
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index c08ffd4cd7d..0a199f008d6 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -20,7 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
@@ -95,7 +95,7 @@ struct VectorizerTestPass : public FunctionPass {
void testNormalizeMaps(Function *f);
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- MLFunctionMatcherContext MLContext;
+ NestedPatternContext MLContext;
static char passID;
};
@@ -153,7 +153,7 @@ static std::string toString(Instruction *inst) {
return res;
}
-static MLFunctionMatches matchTestSlicingOps(Function *f) {
+static NestedMatch matchTestSlicingOps(Function *f) {
// Just use a custom op name for this test, it makes life easier.
constexpr auto kTestSlicingOpName = "slicing-test-op";
using functional::map;
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 29a97991d5e..e9b37fcc04c 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -21,7 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/LoopAnalysis.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
@@ -567,14 +567,14 @@ static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension);
// Build a bunch of predetermined patterns that will be traversed in order.
-// Due to the recursive nature of MLFunctionMatchers, this captures
+// Due to the recursive nature of NestedPatterns, this captures
// arbitrarily nested pairs of loops at any position in the tree.
/// Note that this currently only matches 2 nested loops and will be extended.
// TODO(ntv): support 3-D loop patterns with a common reduction loop that can
// be matched to GEMMs.
-static std::vector<MLFunctionMatcher> defaultPatterns() {
+static std::vector<NestedPattern> defaultPatterns() {
using matcher::For;
- return std::vector<MLFunctionMatcher>{
+ return std::vector<NestedPattern>{
// 3-D patterns
For(isVectorizableLoopPtrFactory(2),
For(isVectorizableLoopPtrFactory(1),
@@ -627,7 +627,7 @@ static std::vector<MLFunctionMatcher> defaultPatterns() {
/// Up to 3-D patterns are supported.
/// If the command line argument requests a pattern of higher order, returns an
/// empty pattern list which will conservatively result in no vectorization.
-static std::vector<MLFunctionMatcher> makePatterns() {
+static std::vector<NestedPattern> makePatterns() {
using matcher::For;
if (clFastestVaryingPattern.empty()) {
return defaultPatterns();
@@ -644,7 +644,7 @@ static std::vector<MLFunctionMatcher> makePatterns() {
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[1]),
For(isVectorizableLoopPtrFactory(clFastestVaryingPattern[2]))))};
default:
- return std::vector<MLFunctionMatcher>();
+ return std::vector<NestedPattern>();
}
}
@@ -656,7 +656,7 @@ struct Vectorize : public FunctionPass {
PassResult runOnFunction(Function *f) override;
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- MLFunctionMatcherContext MLContext;
+ NestedPatternContext MLContext;
static char passID;
};
@@ -703,8 +703,8 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
/// 3. account for impact of vectorization on maximal loop fusion.
/// Then we can quantify the above to build a cost model and search over
/// strategies.
-static bool analyzeProfitability(MLFunctionMatches matches,
- unsigned depthInPattern, unsigned patternDepth,
+static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern,
+ unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
auto *loop = cast<ForInst>(m.first);
@@ -890,7 +890,7 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
return false;
}
-/// Returns a FilterFunctionType that can be used in MLFunctionMatcher to
+/// Returns a FilterFunctionType that can be used in NestedPattern to
/// match a loop whose underlying load/store accesses are all varying along the
/// `fastestVaryingMemRefDimension`.
/// TODO(ntv): In the future, allow more interesting mixed layout permutation
@@ -906,16 +906,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
}
/// Forward-declaration.
-static bool vectorizeNonRoot(MLFunctionMatches matches,
- VectorizationState *state);
+static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state);
/// Apply vectorization of `loop` according to `state`. This is only triggered
/// if all vectorizations in `childrenMatches` have already succeeded
/// recursively in DFS post-order.
-static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
+static bool doVectorize(NestedMatch::EntryType oneMatch,
VectorizationState *state) {
ForInst *loop = cast<ForInst>(oneMatch.first);
- MLFunctionMatches childrenMatches = oneMatch.second;
+ NestedMatch childrenMatches = oneMatch.second;
// 1. DFS postorder recursion, if any of my children fails, I fail too.
auto fail = vectorizeNonRoot(childrenMatches, state);
@@ -949,8 +948,7 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
/// Non-root pattern iterates over the matches at this level, calls doVectorize
/// and exits early if anything below fails.
-static bool vectorizeNonRoot(MLFunctionMatches matches,
- VectorizationState *state) {
+static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) {
for (auto m : matches) {
auto fail = doVectorize(m, state);
if (fail) {
@@ -1186,7 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) {
/// The root match thus needs to maintain a clone for handling failure.
/// Each root may succeed independently but will otherwise clean after itself if
/// anything below it fails.
-static bool vectorizeRootMatches(MLFunctionMatches matches,
+static bool vectorizeRootMatches(NestedMatch matches,
VectorizationStrategy *strategy) {
for (auto m : matches) {
auto *loop = cast<ForInst>(m.first);
OpenPOWER on IntegriCloud