summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Analysis/NestedMatcher.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/NestedMatcher.cpp')
-rw-r--r--mlir/lib/Analysis/NestedMatcher.cpp158
1 files changed, 42 insertions, 116 deletions
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 491a9bef1b9..43e3b332a58 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -20,93 +20,49 @@
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"
-namespace mlir {
-
-/// Underlying storage for NestedMatch.
-struct NestedMatchStorage {
- MutableArrayRef<NestedMatch::EntryType> matches;
-};
-
-/// Underlying storage for NestedPattern.
-struct NestedPatternStorage {
- NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c,
- FilterFunctionType filter, Instruction *skip)
- : kind(k), nestedPatterns(c), filter(filter), skip(skip) {}
-
- Instruction::Kind kind;
- ArrayRef<NestedPattern> nestedPatterns;
- FilterFunctionType filter;
- /// skip is needed so that we can implement match without switching on the
- /// type of the Instruction.
- /// The idea is that a NestedPattern first checks if it matches locally
- /// and then recursively applies its nested matchers to its elem->nested.
- /// Since we want to rely on the InstWalker impl rather than duplicate its
- /// the logic, we allow an off-by-one traversal to account for the fact that
- /// we write:
- ///
- /// void match(Instruction *elem) {
- /// for (auto &c : getNestedPatterns()) {
- /// NestedPattern childPattern(...);
- /// ^~~~ Needs off-by-one skip.
- ///
- Instruction *skip;
-};
-
-} // end namespace mlir
-
using namespace mlir;
llvm::BumpPtrAllocator *&NestedMatch::allocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ thread_local llvm::BumpPtrAllocator *allocator = nullptr;
return allocator;
}
-NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) {
- auto *matches =
- allocator()->Allocate<NestedMatch::EntryType>(elements.size());
- std::uninitialized_copy(elements.begin(), elements.end(), matches);
- auto *storage = allocator()->Allocate<NestedMatchStorage>();
- new (storage) NestedMatchStorage();
- storage->matches =
- MutableArrayRef<NestedMatch::EntryType>(matches, elements.size());
+NestedMatch NestedMatch::build(Instruction *instruction,
+ ArrayRef<NestedMatch> nestedMatches) {
auto *result = allocator()->Allocate<NestedMatch>();
- new (result) NestedMatch(storage);
+ auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
+ std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
+ new (result) NestedMatch();
+ result->matchedInstruction = instruction;
+ result->matchedChildren =
+ ArrayRef<NestedMatch>(children, nestedMatches.size());
return *result;
}
-NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); }
-NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); }
-NestedMatch::EntryType &NestedMatch::front() {
- return *storage->matches.begin();
-}
-NestedMatch::EntryType &NestedMatch::back() {
- return *(storage->matches.begin() + size() - 1);
-}
-
-/// Calls walk on `function`.
-NestedMatch NestedPattern::match(Function *function) {
- assert(!matches && "NestedPattern already matched!");
- this->walkPostOrder(function);
- return matches;
+llvm::BumpPtrAllocator *&NestedPattern::allocator() {
+ thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
}
-/// Calls walk on `instruction`.
-NestedMatch NestedPattern::match(Instruction *instruction) {
- assert(!matches && "NestedPattern already matched!");
- this->walkPostOrder(instruction);
- return matches;
+NestedPattern::NestedPattern(Instruction::Kind k,
+ ArrayRef<NestedPattern> nested,
+ FilterFunctionType filter)
+ : kind(k), nestedPatterns(ArrayRef<NestedPattern>(nested)), filter(filter) {
+ auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
+ std::uninitialized_copy(nested.begin(), nested.end(), newNested);
+ nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
}
-unsigned NestedPattern::getDepth() {
- auto nested = getNestedPatterns();
- if (nested.empty()) {
+unsigned NestedPattern::getDepth() const {
+ if (nestedPatterns.empty()) {
return 1;
}
unsigned depth = 0;
- for (auto c : nested) {
+ for (auto &c : nestedPatterns) {
depth = std::max(depth, c.getDepth());
}
return depth + 1;
@@ -122,69 +78,39 @@ unsigned NestedPattern::getDepth() {
/// appended to the list of matches;
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
/// want to traverse in post-order DFS to avoid invalidating iterators.
-void NestedPattern::matchOne(Instruction *elem) {
- if (storage->skip == elem) {
+void NestedPattern::matchOne(Instruction *inst,
+ SmallVectorImpl<NestedMatch> *matches) {
+ if (skip == inst) {
return;
}
// Structural filter
- if (elem->getKind() != getKind()) {
+ if (inst->getKind() != kind) {
return;
}
// Local custom filter function
- if (!getFilterFunction()(*elem)) {
+ if (!filter(*inst)) {
return;
}
- SmallVector<NestedMatch::EntryType, 8> nestedEntries;
- for (auto c : getNestedPatterns()) {
- /// We create a new nestedPattern here because a matcher holds its
- /// results. So we concretely need multiple copies of a given matcher, one
- /// for each matching result.
- NestedPattern nestedPattern(c);
+ if (nestedPatterns.empty()) {
+ SmallVector<NestedMatch, 8> nestedMatches;
+ matches->push_back(NestedMatch::build(inst, nestedMatches));
+ return;
+ }
+ // Take a copy of each nested pattern so we can match it.
+ for (auto nestedPattern : nestedPatterns) {
+ SmallVector<NestedMatch, 8> nestedMatches;
// Skip elem in the walk immediately following. Without this we would
// essentially need to reimplement walkPostOrder here.
- nestedPattern.storage->skip = elem;
- nestedPattern.walkPostOrder(elem);
- if (!nestedPattern.matches) {
+ nestedPattern.skip = inst;
+ nestedPattern.match(inst, &nestedMatches);
+ // If we could not match even one of the specified nestedPattern, early exit
+ // as this whole branch is not a match.
+ if (nestedMatches.empty()) {
return;
}
- for (auto m : nestedPattern.matches) {
- nestedEntries.push_back(m);
- }
+ matches->push_back(NestedMatch::build(inst, nestedMatches));
}
-
- SmallVector<NestedMatch::EntryType, 8> newEntries(
- matches.storage->matches.begin(), matches.storage->matches.end());
- newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries)));
- matches = NestedMatch::build(newEntries);
-}
-
-llvm::BumpPtrAllocator *&NestedPattern::allocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
- return allocator;
-}
-
-NestedPattern::NestedPattern(Instruction::Kind k,
- ArrayRef<NestedPattern> nested,
- FilterFunctionType filter)
- : storage(allocator()->Allocate<NestedPatternStorage>()),
- matches(NestedMatch::build({})) {
- auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size());
- std::uninitialized_copy(nested.begin(), nested.end(), newChildren);
- // Initialize with placement new.
- new (storage) NestedPatternStorage(
- k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter,
- nullptr /* skip */);
-}
-
-Instruction::Kind NestedPattern::getKind() { return storage->kind; }
-
-ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() {
- return storage->nestedPatterns;
-}
-
-FilterFunctionType NestedPattern::getFilterFunction() {
- return storage->filter;
}
static bool isAffineIfOp(const Instruction &inst) {
OpenPOWER on IntegriCloud