//===- NestedMatcher.cpp - NestedMatcher Impl ------------------*- C++ -*-===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; llvm::BumpPtrAllocator *&NestedMatch::allocator() { thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } NestedMatch NestedMatch::build(Instruction *instruction, ArrayRef nestedMatches) { auto *result = allocator()->Allocate(); auto *children = allocator()->Allocate(nestedMatches.size()); std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); new (result) NestedMatch(); result->matchedInstruction = instruction; result->matchedChildren = ArrayRef(children, nestedMatches.size()); return *result; } llvm::BumpPtrAllocator *&NestedPattern::allocator() { thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } NestedPattern::NestedPattern(Instruction::Kind k, ArrayRef nested, FilterFunctionType filter) : kind(k), nestedPatterns(ArrayRef(nested)), filter(filter) { auto *newNested = allocator()->Allocate(nested.size()); std::uninitialized_copy(nested.begin(), nested.end(), newNested); nestedPatterns = ArrayRef(newNested, nested.size()); } unsigned NestedPattern::getDepth() const { if (nestedPatterns.empty()) { return 1; } unsigned depth = 0; for (auto &c : nestedPatterns) { depth = std::max(depth, c.getDepth()); } return depth + 1; } /// Matches a single instruction in the following way: /// 1. checks the kind of instruction against the matcher, if different then /// there is no match; /// 2. calls the customizable filter function to refine the single instruction /// match with extra semantic constraints; /// 3. if all is good, recursivey matches the nested patterns; /// 4. if all nested match then the single instruction matches too and is /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. void NestedPattern::matchOne(Instruction *inst, SmallVectorImpl *matches) { if (skip == inst) { return; } // Structural filter if (inst->getKind() != kind) { return; } // Local custom filter function if (!filter(*inst)) { return; } if (nestedPatterns.empty()) { SmallVector nestedMatches; matches->push_back(NestedMatch::build(inst, nestedMatches)); return; } // Take a copy of each nested pattern so we can match it. for (auto nestedPattern : nestedPatterns) { SmallVector nestedMatches; // Skip elem in the walk immediately following. Without this we would // essentially need to reimplement walkPostOrder here. nestedPattern.skip = inst; nestedPattern.match(inst, &nestedMatches); // If we could not match even one of the specified nestedPattern, early exit // as this whole branch is not a match. if (nestedMatches.empty()) { return; } matches->push_back(NestedMatch::build(inst, nestedMatches)); } } static bool isAffineIfOp(const Instruction &inst) { return isa(inst) && cast(inst).isa(); } namespace mlir { namespace matcher { NestedPattern Op(FilterFunctionType filter) { return NestedPattern(Instruction::Kind::OperationInst, {}, filter); } NestedPattern If(NestedPattern child) { return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { return NestedPattern(Instruction::Kind::OperationInst, child, [filter](const Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } NestedPattern If(ArrayRef nested) { return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { return NestedPattern(Instruction::Kind::OperationInst, nested, [filter](const Instruction &inst) { return isAffineIfOp(inst) && filter(inst); }); } NestedPattern For(NestedPattern child) { return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { return NestedPattern(Instruction::Kind::For, child, filter); } NestedPattern For(ArrayRef nested) { return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { return NestedPattern(Instruction::Kind::For, nested, filter); } // TODO(ntv): parallel annotation on loops. bool isParallelLoop(const Instruction &inst) { const auto *loop = cast(&inst); return (void *)loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. bool isReductionLoop(const Instruction &inst) { const auto *loop = cast(&inst); return (void *)loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { const auto *opInst = dyn_cast(&inst); return opInst && (opInst->isa() || opInst->isa()); }; } // end namespace matcher } // end namespace mlir