//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/raw_ostream.h" namespace mlir { /// Underlying storage for NestedMatch. struct NestedMatchStorage { MutableArrayRef matches; }; /// Underlying storage for NestedPattern. struct NestedPatternStorage { NestedPatternStorage(Instruction::Kind k, ArrayRef c, FilterFunctionType filter, Instruction *skip) : kind(k), nestedPatterns(c), filter(filter), skip(skip) {} Instruction::Kind kind; ArrayRef nestedPatterns; FilterFunctionType filter; /// skip is needed so that we can implement match without switching on the /// type of the Instruction. /// The idea is that a NestedPattern first checks if it matches locally /// and then recursively applies its nested matchers to its elem->nested. /// Since we want to rely on the InstWalker impl rather than duplicate its /// the logic, we allow an off-by-one traversal to account for the fact that /// we write: /// /// void match(Instruction *elem) { /// for (auto &c : getNestedPatterns()) { /// NestedPattern childPattern(...); /// ^~~~ Needs off-by-one skip. /// Instruction *skip; }; } // end namespace mlir using namespace mlir; llvm::BumpPtrAllocator *&NestedMatch::allocator() { static thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } NestedMatch NestedMatch::build(ArrayRef elements) { auto *matches = allocator()->Allocate(elements.size()); std::uninitialized_copy(elements.begin(), elements.end(), matches); auto *storage = allocator()->Allocate(); new (storage) NestedMatchStorage(); storage->matches = MutableArrayRef(matches, elements.size()); auto *result = allocator()->Allocate(); new (result) NestedMatch(storage); return *result; } NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); } NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); } NestedMatch::EntryType &NestedMatch::front() { return *storage->matches.begin(); } NestedMatch::EntryType &NestedMatch::back() { return *(storage->matches.begin() + size() - 1); } /// Calls walk on `function`. NestedMatch NestedPattern::match(Function *function) { assert(!matches && "NestedPattern already matched!"); this->walkPostOrder(function); return matches; } /// Calls walk on `instruction`. NestedMatch NestedPattern::match(Instruction *instruction) { assert(!matches && "NestedPattern already matched!"); this->walkPostOrder(instruction); return matches; } unsigned NestedPattern::getDepth() { auto nested = getNestedPatterns(); if (nested.empty()) { return 1; } unsigned depth = 0; for (auto c : nested) { depth = std::max(depth, c.getDepth()); } return depth + 1; } /// Matches a single instruction in the following way: /// 1. checks the kind of instruction against the matcher, if different then /// there is no match; /// 2. calls the customizable filter function to refine the single instruction /// match with extra semantic constraints; /// 3. if all is good, recursivey matches the nested patterns; /// 4. if all nested match then the single instruction matches too and is /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. void NestedPattern::matchOne(Instruction *elem) { if (storage->skip == elem) { return; } // Structural filter if (elem->getKind() != getKind()) { return; } // Local custom filter function if (!getFilterFunction()(*elem)) { return; } SmallVector nestedEntries; for (auto c : getNestedPatterns()) { /// We create a new nestedPattern here because a matcher holds its /// results. So we concretely need multiple copies of a given matcher, one /// for each matching result. NestedPattern nestedPattern(c); // Skip elem in the walk immediately following. Without this we would // essentially need to reimplement walkPostOrder here. nestedPattern.storage->skip = elem; nestedPattern.walkPostOrder(elem); if (!nestedPattern.matches) { return; } for (auto m : nestedPattern.matches) { nestedEntries.push_back(m); } } SmallVector newEntries( matches.storage->matches.begin(), matches.storage->matches.end()); newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries))); matches = NestedMatch::build(newEntries); } llvm::BumpPtrAllocator *&NestedPattern::allocator() { static thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } NestedPattern::NestedPattern(Instruction::Kind k, ArrayRef nested, FilterFunctionType filter) : storage(allocator()->Allocate()), matches(NestedMatch::build({})) { auto *newChildren = allocator()->Allocate(nested.size()); std::uninitialized_copy(nested.begin(), nested.end(), newChildren); // Initialize with placement new. new (storage) NestedPatternStorage( k, ArrayRef(newChildren, nested.size()), filter, nullptr /* skip */); } Instruction::Kind NestedPattern::getKind() { return storage->kind; } ArrayRef NestedPattern::getNestedPatterns() { return storage->nestedPatterns; } FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } static bool isAffineIfOp(const Instruction &inst) { return isa(inst) && cast(inst).isa(); } namespace mlir { namespace matcher { NestedPattern Op(FilterFunctionType filter) { return NestedPattern(Instruction::Kind::OperationInst, {}, filter); } NestedPattern If(NestedPattern child) { return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { return NestedPattern(Instruction::Kind::OperationInst, child, [filter](const Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } NestedPattern If(ArrayRef nested) { return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { return NestedPattern(Instruction::Kind::OperationInst, nested, [filter](const Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } NestedPattern For(NestedPattern child) { return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { return NestedPattern(Instruction::Kind::For, child, filter); } NestedPattern For(ArrayRef nested) { return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { return NestedPattern(Instruction::Kind::For, nested, filter); } // TODO(ntv): parallel annotation on loops. bool isParallelLoop(const Instruction &inst) { const auto *loop = cast(&inst); return (void *)loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. bool isReductionLoop(const Instruction &inst) { const auto *loop = cast(&inst); return (void *)loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { const auto *opInst = dyn_cast(&inst); return opInst && (opInst->isa() || opInst->isa()); }; } // end namespace matcher } // end namespace mlir