diff options
Diffstat (limited to 'mlir/lib/Analysis/MLFunctionMatcher.cpp')
| -rw-r--r-- | mlir/lib/Analysis/MLFunctionMatcher.cpp | 80 |
1 files changed, 41 insertions, 39 deletions
diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp index 12ce8481516..5bb4548e670 100644 --- a/mlir/lib/Analysis/MLFunctionMatcher.cpp +++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp @@ -31,29 +31,29 @@ struct MLFunctionMatchesStorage { /// Underlying storage for MLFunctionMatcher. struct MLFunctionMatcherStorage { - MLFunctionMatcherStorage(Statement::Kind k, + MLFunctionMatcherStorage(Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> c, - FilterFunctionType filter, Statement *skip) + FilterFunctionType filter, Instruction *skip) : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter), skip(skip) {} - Statement::Kind kind; + Instruction::Kind kind; SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers; FilterFunctionType filter; /// skip is needed so that we can implement match without switching on the - /// type of the Statement. + /// type of the Instruction. /// The idea is that a MLFunctionMatcher first checks if it matches locally /// and then recursively applies its children matchers to its elem->children. - /// Since we want to rely on the StmtWalker impl rather than duplicate its + /// Since we want to rely on the InstWalker impl rather than duplicate its /// the logic, we allow an off-by-one traversal to account for the fact that /// we write: /// - /// void match(Statement *elem) { + /// void match(Instruction *elem) { /// for (auto &c : getChildrenMLFunctionMatchers()) { /// MLFunctionMatcher childMLFunctionMatcher(...); /// ^~~~ Needs off-by-one skip. /// - Statement *skip; + Instruction *skip; }; } // end namespace mlir @@ -65,12 +65,12 @@ llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() { return allocator; } -void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) { +void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) { if (!storage) { storage = allocator()->Allocate<MLFunctionMatchesStorage>(); - new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children)); + new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children)); } else { - storage->matches.push_back(std::make_pair(stmt, children)); + storage->matches.push_back(std::make_pair(inst, children)); } } MLFunctionMatches::iterator MLFunctionMatches::begin() { @@ -98,10 +98,10 @@ MLFunctionMatches MLFunctionMatcher::match(Function *function) { return matches; } -/// Calls walk on `statement`. -MLFunctionMatches MLFunctionMatcher::match(Statement *statement) { +/// Calls walk on `instruction`. +MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) { assert(!matches && "MLFunctionMatcher already matched!"); - this->walkPostOrder(statement); + this->walkPostOrder(instruction); return matches; } @@ -117,17 +117,17 @@ unsigned MLFunctionMatcher::getDepth() { return depth + 1; } -/// Matches a single statement in the following way: -/// 1. checks the kind of statement against the matcher, if different then +/// Matches a single instruction in the following way: +/// 1. checks the kind of instruction against the matcher, if different then /// there is no match; -/// 2. calls the customizable filter function to refine the single statement +/// 2. calls the customizable filter function to refine the single instruction /// match with extra semantic constraints; /// 3. if all is good, recursivey matches the children patterns; -/// 4. if all children match then the single statement matches too and is +/// 4. if all children match then the single instruction matches too and is /// appended to the list of matches; /// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will /// want to traverse in post-order DFS to avoid invalidating iterators. -void MLFunctionMatcher::matchOne(Statement *elem) { +void MLFunctionMatcher::matchOne(Instruction *elem) { if (storage->skip == elem) { return; } @@ -159,7 +159,8 @@ llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() { return allocator; } -MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child, +MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k, + MLFunctionMatcher child, FilterFunctionType filter) : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) { // Initialize with placement new. @@ -168,7 +169,7 @@ MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child, } MLFunctionMatcher::MLFunctionMatcher( - Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children, + Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children, FilterFunctionType filter) : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) { // Initialize with placement new. @@ -178,14 +179,14 @@ MLFunctionMatcher::MLFunctionMatcher( MLFunctionMatcher MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl, - Statement *stmt) { + Instruction *inst) { MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(), tmpl.getFilterFunction()); - res.storage->skip = stmt; + res.storage->skip = inst; return res; } -Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; } +Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; } MutableArrayRef<MLFunctionMatcher> MLFunctionMatcher::getChildrenMLFunctionMatchers() { @@ -200,54 +201,55 @@ namespace mlir { namespace matcher { MLFunctionMatcher Op(FilterFunctionType filter) { - return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter); + return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter); } MLFunctionMatcher If(MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction); + return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction); } MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::If, child, filter); + return MLFunctionMatcher(Instruction::Kind::If, child, filter); } MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Statement::Kind::If, children, + return MLFunctionMatcher(Instruction::Kind::If, children, defaultFilterFunction); } MLFunctionMatcher If(FilterFunctionType filter, MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Statement::Kind::If, children, filter); + return MLFunctionMatcher(Instruction::Kind::If, children, filter); } MLFunctionMatcher For(MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction); + return MLFunctionMatcher(Instruction::Kind::For, child, + defaultFilterFunction); } MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) { - return MLFunctionMatcher(Statement::Kind::For, child, filter); + return MLFunctionMatcher(Instruction::Kind::For, child, filter); } MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Statement::Kind::For, children, + return MLFunctionMatcher(Instruction::Kind::For, children, defaultFilterFunction); } MLFunctionMatcher For(FilterFunctionType filter, MutableArrayRef<MLFunctionMatcher> children) { - return MLFunctionMatcher(Statement::Kind::For, children, filter); + return MLFunctionMatcher(Instruction::Kind::For, children, filter); } // TODO(ntv): parallel annotation on loops. -bool isParallelLoop(const Statement &stmt) { - const auto *loop = cast<ForStmt>(&stmt); +bool isParallelLoop(const Instruction &inst) { + const auto *loop = cast<ForInst>(&inst); return (void *)loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. -bool isReductionLoop(const Statement &stmt) { - const auto *loop = cast<ForStmt>(&stmt); +bool isReductionLoop(const Instruction &inst) { + const auto *loop = cast<ForInst>(&inst); return (void *)loop || true; // loop->isReduction(); }; -bool isLoadOrStore(const Statement &stmt) { - const auto *opStmt = dyn_cast<OperationInst>(&stmt); - return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()); +bool isLoadOrStore(const Instruction &inst) { + const auto *opInst = dyn_cast<OperationInst>(&inst); + return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()); }; } // end namespace matcher |

