diff options
Diffstat (limited to 'mlir/lib/Analysis/NestedMatcher.cpp')
| -rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 158 |
1 files changed, 42 insertions, 116 deletions
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 491a9bef1b9..43e3b332a58 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -20,93 +20,49 @@ #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/raw_ostream.h" -namespace mlir { - -/// Underlying storage for NestedMatch. -struct NestedMatchStorage { - MutableArrayRef<NestedMatch::EntryType> matches; -}; - -/// Underlying storage for NestedPattern. -struct NestedPatternStorage { - NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c, - FilterFunctionType filter, Instruction *skip) - : kind(k), nestedPatterns(c), filter(filter), skip(skip) {} - - Instruction::Kind kind; - ArrayRef<NestedPattern> nestedPatterns; - FilterFunctionType filter; - /// skip is needed so that we can implement match without switching on the - /// type of the Instruction. - /// The idea is that a NestedPattern first checks if it matches locally - /// and then recursively applies its nested matchers to its elem->nested. - /// Since we want to rely on the InstWalker impl rather than duplicate its - /// the logic, we allow an off-by-one traversal to account for the fact that - /// we write: - /// - /// void match(Instruction *elem) { - /// for (auto &c : getNestedPatterns()) { - /// NestedPattern childPattern(...); - /// ^~~~ Needs off-by-one skip. - /// - Instruction *skip; -}; - -} // end namespace mlir - using namespace mlir; llvm::BumpPtrAllocator *&NestedMatch::allocator() { - static thread_local llvm::BumpPtrAllocator *allocator = nullptr; + thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } -NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) { - auto *matches = - allocator()->Allocate<NestedMatch::EntryType>(elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), matches); - auto *storage = allocator()->Allocate<NestedMatchStorage>(); - new (storage) NestedMatchStorage(); - storage->matches = - MutableArrayRef<NestedMatch::EntryType>(matches, elements.size()); +NestedMatch NestedMatch::build(Instruction *instruction, + ArrayRef<NestedMatch> nestedMatches) { auto *result = allocator()->Allocate<NestedMatch>(); - new (result) NestedMatch(storage); + auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size()); + std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); + new (result) NestedMatch(); + result->matchedInstruction = instruction; + result->matchedChildren = + ArrayRef<NestedMatch>(children, nestedMatches.size()); return *result; } -NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); } -NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); } -NestedMatch::EntryType &NestedMatch::front() { - return *storage->matches.begin(); -} -NestedMatch::EntryType &NestedMatch::back() { - return *(storage->matches.begin() + size() - 1); -} - -/// Calls walk on `function`. -NestedMatch NestedPattern::match(Function *function) { - assert(!matches && "NestedPattern already matched!"); - this->walkPostOrder(function); - return matches; +llvm::BumpPtrAllocator *&NestedPattern::allocator() { + thread_local llvm::BumpPtrAllocator *allocator = nullptr; + return allocator; } -/// Calls walk on `instruction`. -NestedMatch NestedPattern::match(Instruction *instruction) { - assert(!matches && "NestedPattern already matched!"); - this->walkPostOrder(instruction); - return matches; +NestedPattern::NestedPattern(Instruction::Kind k, + ArrayRef<NestedPattern> nested, + FilterFunctionType filter) + : kind(k), nestedPatterns(ArrayRef<NestedPattern>(nested)), filter(filter) { + auto *newNested = allocator()->Allocate<NestedPattern>(nested.size()); + std::uninitialized_copy(nested.begin(), nested.end(), newNested); + nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size()); } -unsigned NestedPattern::getDepth() { - auto nested = getNestedPatterns(); - if (nested.empty()) { +unsigned NestedPattern::getDepth() const { + if (nestedPatterns.empty()) { return 1; } unsigned depth = 0; - for (auto c : nested) { + for (auto &c : nestedPatterns) { depth = std::max(depth, c.getDepth()); } return depth + 1; @@ -122,69 +78,39 @@ unsigned NestedPattern::getDepth() { /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. -void NestedPattern::matchOne(Instruction *elem) { - if (storage->skip == elem) { +void NestedPattern::matchOne(Instruction *inst, + SmallVectorImpl<NestedMatch> *matches) { + if (skip == inst) { return; } // Structural filter - if (elem->getKind() != getKind()) { + if (inst->getKind() != kind) { return; } // Local custom filter function - if (!getFilterFunction()(*elem)) { + if (!filter(*inst)) { return; } - SmallVector<NestedMatch::EntryType, 8> nestedEntries; - for (auto c : getNestedPatterns()) { - /// We create a new nestedPattern here because a matcher holds its - /// results. So we concretely need multiple copies of a given matcher, one - /// for each matching result. - NestedPattern nestedPattern(c); + if (nestedPatterns.empty()) { + SmallVector<NestedMatch, 8> nestedMatches; + matches->push_back(NestedMatch::build(inst, nestedMatches)); + return; + } + // Take a copy of each nested pattern so we can match it. + for (auto nestedPattern : nestedPatterns) { + SmallVector<NestedMatch, 8> nestedMatches; // Skip elem in the walk immediately following. Without this we would // essentially need to reimplement walkPostOrder here. - nestedPattern.storage->skip = elem; - nestedPattern.walkPostOrder(elem); - if (!nestedPattern.matches) { + nestedPattern.skip = inst; + nestedPattern.match(inst, &nestedMatches); + // If we could not match even one of the specified nestedPattern, early exit + // as this whole branch is not a match. + if (nestedMatches.empty()) { return; } - for (auto m : nestedPattern.matches) { - nestedEntries.push_back(m); - } + matches->push_back(NestedMatch::build(inst, nestedMatches)); } - - SmallVector<NestedMatch::EntryType, 8> newEntries( - matches.storage->matches.begin(), matches.storage->matches.end()); - newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries))); - matches = NestedMatch::build(newEntries); -} - -llvm::BumpPtrAllocator *&NestedPattern::allocator() { - static thread_local llvm::BumpPtrAllocator *allocator = nullptr; - return allocator; -} - -NestedPattern::NestedPattern(Instruction::Kind k, - ArrayRef<NestedPattern> nested, - FilterFunctionType filter) - : storage(allocator()->Allocate<NestedPatternStorage>()), - matches(NestedMatch::build({})) { - auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size()); - std::uninitialized_copy(nested.begin(), nested.end(), newChildren); - // Initialize with placement new. - new (storage) NestedPatternStorage( - k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter, - nullptr /* skip */); -} - -Instruction::Kind NestedPattern::getKind() { return storage->kind; } - -ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() { - return storage->nestedPatterns; -} - -FilterFunctionType NestedPattern::getFilterFunction() { - return storage->filter; } static bool isAffineIfOp(const Instruction &inst) { |

