//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

llvm::BumpPtrAllocator *&NestedMatch::allocator() {
  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
  return allocator;
}

NestedMatch NestedMatch::build(Instruction *instruction,
                               ArrayRef<NestedMatch> nestedMatches) {
  auto *result = allocator()->Allocate<NestedMatch>();
  auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(),
                          children);
  new (result) NestedMatch();
  result->matchedInstruction = instruction;
  result->matchedChildren =
      ArrayRef<NestedMatch>(children, nestedMatches.size());
  return *result;
}

llvm::BumpPtrAllocator *&NestedPattern::allocator() {
  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
  return allocator;
}

NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
                             FilterFunctionType filter)
    : nestedPatterns(), filter(filter), skip(nullptr) {
  if (!nested.empty()) {
    auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
    std::uninitialized_copy(nested.begin(), nested.end(), newNested);
    nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
  }
}

unsigned NestedPattern::getDepth() const {
  if (nestedPatterns.empty()) {
    return 1;
  }
  unsigned depth = 0;
  for (auto &c : nestedPatterns) {
    depth = std::max(depth, c.getDepth());
  }
  return depth + 1;
}

/// Matches a single instruction in the following way:
///   1. checks the kind of instruction against the matcher; if different,
///      there is no match;
///   2. calls the customizable filter function to refine the single
///      instruction match with extra semantic constraints;
///   3. if all is good, recursively matches the nested patterns;
///   4. if all nested patterns match, the single instruction matches too and
///      is appended to the list of matches;
///   5. TODO(ntv) Optionally applies actions (lambda), in which case we will
///      want to traverse in post-order DFS to avoid invalidating iterators.
void NestedPattern::matchOne(Instruction *inst,
                             SmallVectorImpl<NestedMatch> *matches) {
  if (skip == inst) {
    return;
  }
  // Local custom filter function.
  if (!filter(*inst)) {
    return;
  }

  if (nestedPatterns.empty()) {
    SmallVector<NestedMatch, 8> nestedMatches;
    matches->push_back(NestedMatch::build(inst, nestedMatches));
    return;
  }
  // Take a copy of each nested pattern so we can match it (matching mutates
  // the pattern's `skip` field).
  for (auto nestedPattern : nestedPatterns) {
    SmallVector<NestedMatch, 8> nestedMatches;
    // Skip `inst` itself in the walk that immediately follows; without this
    // we would essentially need to reimplement walkPostOrder here.
    nestedPattern.skip = inst;
    nestedPattern.match(inst, &nestedMatches);
    // If we could not match even one of the specified nested patterns,
    // early-exit: this whole branch is not a match.
    if (nestedMatches.empty()) {
      return;
    }
    matches->push_back(NestedMatch::build(inst, nestedMatches));
  }
}
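// A minimal usage sketch of the recursive matching above (hedged: it assumes
// the `match(Function *, SmallVectorImpl<NestedMatch> *)` entry point and the
// `getMatchedChildren`/`getMatchedInstruction` accessors declared in
// NestedMatcher.h, plus a caller-provided `Function *f`). It builds a pattern
// for load/store instructions nested under two affine loops, then walks the
// resulting matches:
//
//   using namespace mlir::matcher;
//   NestedPattern pattern = For(For(Op(isLoadOrStore)));
//   SmallVector<NestedMatch, 8> matches;
//   pattern.match(f, &matches);
//   for (NestedMatch &outer : matches)
//     for (NestedMatch &inner : outer.getMatchedChildren())
//       inner.getMatchedInstruction()->dump();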
static bool isAffineForOp(const Instruction &inst) {
  return inst.isa<AffineForOp>();
}

static bool isAffineIfOp(const Instruction &inst) {
  return inst.isa<AffineIfOp>();
}

namespace mlir {
namespace matcher {

NestedPattern Op(FilterFunctionType filter) {
  return NestedPattern({}, filter);
}

NestedPattern If(NestedPattern child) {
  return NestedPattern(child, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
  return NestedPattern(child, [filter](const Instruction &inst) {
    return isAffineIfOp(inst) && filter(inst);
  });
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, [filter](const Instruction &inst) {
    return isAffineIfOp(inst) && filter(inst);
  });
}

NestedPattern For(NestedPattern child) {
  return NestedPattern(child, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
  return NestedPattern(child, [=](const Instruction &inst) {
    return isAffineForOp(inst) && filter(inst);
  });
}
NestedPattern For(ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, [=](const Instruction &inst) {
    return isAffineForOp(inst) && filter(inst);
  });
}

// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(const Instruction &inst) {
  auto loop = inst.cast<AffineForOp>();
  return loop || true; // loop->isParallel();
}

// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(const Instruction &inst) {
  auto loop = inst.cast<AffineForOp>();
  return loop || true; // loop->isReduction();
}

bool isLoadOrStore(const Instruction &inst) {
  return inst.isa<LoadOp>() || inst.isa<StoreOp>();
}

} // end namespace matcher
} // end namespace mlir
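// A hypothetical composition sketch for the factory functions above (the
// function `f` and the predicate body are illustrative assumptions, not part
// of this file): the filtered `For` overloads conjoin a custom predicate with
// `isAffineForOp`, so a caller can constrain which loops participate in the
// structural match without re-checking the instruction kind.
//
//   using namespace mlir::matcher;
//   FilterFunctionType customFilter = [](const Instruction &inst) {
//     return /* caller-specific test, e.g. on the loop bounds */ true;
//   };
//   auto pattern = For(customFilter, For(Op(isLoadOrStore)));
//   SmallVector<NestedMatch, 8> matches;
//   pattern.match(f, &matches);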