diff options
| author | River Riddle <riverriddle@google.com> | 2019-02-04 16:24:44 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:12:59 -0700 |
| commit | bf9c381d1dbf4381659597109422e543d62a49d7 (patch) | |
| tree | 493aaa02d23039a9fcd31b9ff4b3a4f0af91df3a /mlir | |
| parent | c9ad4621ce2d68cad547da360aedeee733b73f32 (diff) | |
| download | bcm5719-llvm-bf9c381d1dbf4381659597109422e543d62a49d7.tar.gz bcm5719-llvm-bf9c381d1dbf4381659597109422e543d62a49d7.zip | |
Remove InstWalker and move all instruction walking to the api facilities on Function/Block/Instruction.
PiperOrigin-RevId: 232388113
Diffstat (limited to 'mlir')
25 files changed, 277 insertions, 455 deletions
diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 43ad0354b01..5ec536f47e7 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -229,14 +229,6 @@ public: /// (same operands in the same order). bool matchingBoundOperandList() const; - /// Walk the operation instructions in the 'for' instruction in preorder, - /// calling the callback for each operation. - void walk(std::function<void(Instruction *)> callback); - - /// Walk the operation instructions in the 'for' instruction in postorder, - /// calling the callback for each operation. - void walkPostOrder(std::function<void(Instruction *)> callback); - private: friend class Instruction; explicit AffineForOp(const Instruction *state) : Op(state) {} diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index aba0e11ab91..44fe4c0558a 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -18,7 +18,7 @@ #ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ #define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_ -#include "mlir/IR/InstVisitor.h" +#include "mlir/IR/Function.h" #include "llvm/Support/Allocator.h" namespace mlir { @@ -76,7 +76,7 @@ private: ArrayRef<NestedMatch> matchedChildren; }; -/// A NestedPattern is a nested InstWalker that: +/// A NestedPattern is a nested instruction walker that: /// 1. recursively matches a substructure in the tree; /// 2. uses a filter function to refine matches with extra semantic /// constraints (passed via a lambda of type FilterFunctionType); @@ -92,8 +92,8 @@ private: /// /// The NestedMatches captured in the IR can grow large, especially after /// aggressive unrolling. As experience has shown, it is generally better to use -/// a plain InstWalker to match flat patterns but the current implementation is -/// competitive nonetheless. +/// a plain walk over instructions to match flat patterns but the current +/// implementation is competitive nonetheless. using FilterFunctionType = std::function<bool(const Instruction &)>; static bool defaultFilterFunction(const Instruction &) { return true; }; struct NestedPattern { @@ -102,16 +102,14 @@ struct NestedPattern { NestedPattern(const NestedPattern &) = default; NestedPattern &operator=(const NestedPattern &) = default; - /// Returns all the top-level matches in `function`. - void match(Function *function, SmallVectorImpl<NestedMatch> *matches) { - State state(*this, matches); - state.walkPostOrder(function); + /// Returns all the top-level matches in `func`. + void match(Function *func, SmallVectorImpl<NestedMatch> *matches) { + func->walkPostOrder([&](Instruction *inst) { matchOne(inst, matches); }); } /// Returns all the top-level matches in `inst`. void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) { - State state(*this, matches); - state.walkPostOrder(inst); + inst->walkPostOrder([&](Instruction *child) { matchOne(child, matches); }); } /// Returns the depth of the pattern. @@ -120,22 +118,8 @@ struct NestedPattern { private: friend class NestedPatternContext; friend class NestedMatch; - friend class InstWalker<NestedPattern>; friend struct State; - /// Helper state that temporarily holds matches for the next level of nesting. - struct State : public InstWalker<State> { - State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches) - : pattern(pattern), matches(matches) {} - void visitInstruction(Instruction *opInst) { - pattern.matchOne(opInst, matches); - } - - private: - NestedPattern &pattern; - SmallVectorImpl<NestedMatch> *matches; - }; - /// Underlying global bump allocator managed by a NestedPatternContext. static llvm::BumpPtrAllocator *&allocator(); @@ -153,8 +137,9 @@ private: /// without switching on the type of the Instruction. The idea is that a /// NestedPattern first checks if it matches locally and then recursively /// applies its nested matchers to its elem->nested. Since we want to rely on - /// the InstWalker impl rather than duplicate its the logic, we allow an - /// off-by-one traversal to account for the fact that we write: + /// the existing instruction walking functionality rather than duplicate + /// it, we allow an off-by-one traversal to account for the fact that we + /// write: /// /// void match(Instruction *elem) { /// for (auto &c : getNestedPatterns()) { diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f3a2218d0f9..6e44770282b 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -288,6 +288,28 @@ public: llvm::iterator_range<succ_iterator> getSuccessors(); //===--------------------------------------------------------------------===// + // Instruction Walkers + //===--------------------------------------------------------------------===// + + /// Walk the instructions of this block in preorder, calling the callback for + /// each operation. + void walk(const std::function<void(Instruction *)> &callback); + + /// Walk the instructions in the specified [begin, end) range of + /// this block, calling the callback for each operation. + void walk(Block::iterator begin, Block::iterator end, + const std::function<void(Instruction *)> &callback); + + /// Walk the instructions in this block in postorder, calling the callback for + /// each operation. + void walkPostOrder(const std::function<void(Instruction *)> &callback); + + /// Walk the instructions in the specified [begin, end) range of this block + /// in postorder, calling the callback for each operation. + void walkPostOrder(Block::iterator begin, Block::iterator end, + const std::function<void(Instruction *)> &callback); + + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// @@ -311,19 +333,6 @@ public: return &Block::instructions; } - /// Walk the instructions of this block in preorder, calling the callback for - /// each operation. - void walk(std::function<void(Instruction *)> callback); - - /// Walk the instructions in this block in postorder, calling the callback for - /// each operation. - void walkPostOrder(std::function<void(Instruction *)> callback); - - /// Walk the instructions in the specified [begin, end) range of - /// this block, calling the callback for each operation. - void walk(Block::iterator begin, Block::iterator end, - std::function<void(Instruction *)> callback); - void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index f483ff46259..3afb021c8ec 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -27,6 +27,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Identifier.h" +#include "mlir/IR/Instruction.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" @@ -39,6 +40,7 @@ class FunctionType; class MLIRContext; class Module; template <typename ObjectType, typename ElementType> class ArgumentIterator; +template <typename T> class OpPointer; /// NamedAttribute is used for function attribute lists, it holds an /// identifier for the name and a value for the attribute. The attribute @@ -115,13 +117,35 @@ public: Block &front() { return blocks.front(); } const Block &front() const { return const_cast<Function *>(this)->front(); } + //===--------------------------------------------------------------------===// + // Instruction Walkers + //===--------------------------------------------------------------------===// + /// Walk the instructions in the function in preorder, calling the callback - /// for each instruction or operation. - void walk(std::function<void(Instruction *)> callback); + /// for each instruction. + void walk(const std::function<void(Instruction *)> &callback); + + /// Specialization of walk to only visit operations of 'OpTy'. + template <typename OpTy> + void walk(std::function<void(OpPointer<OpTy>)> callback) { + walk([&](Instruction *inst) { + if (auto op = inst->dyn_cast<OpTy>()) + callback(op); + }); + } /// Walk the instructions in the function in postorder, calling the callback - /// for each instruction or operation. - void walkPostOrder(std::function<void(Instruction *)> callback); + /// for each instruction. + void walkPostOrder(const std::function<void(Instruction *)> &callback); + + /// Specialization of walkPostOrder to only visit operations of 'OpTy'. + template <typename OpTy> + void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) { + walkPostOrder([&](Instruction *inst) { + if (auto op = inst->dyn_cast<OpTy>()) + callback(op); + }); + } //===--------------------------------------------------------------------===// // Arguments diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h deleted file mode 100644 index e11b7350894..00000000000 --- a/mlir/include/mlir/IR/InstVisitor.h +++ /dev/null @@ -1,140 +0,0 @@ -//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines the base classes for Function's instruction visitors and -// walkers. A visit is a O(1) operation that visits just the node in question. A -// walk visits the node it's called on as well as the node's descendants. -// -// Instruction visitors/walkers are used when you want to perform different -// actions for different kinds of instructions without having to use lots of -// casts and a big switch instruction. -// -// To define your own visitor/walker, inherit from these classes, specifying -// your new type for the 'SubClass' template parameter, and "override" visitXXX -// functions in your class. This class is defined in terms of statically -// resolved overloading, not virtual functions. -// -// For example, here is a walker that counts the number of for loops in an -// Function. -// -// /// Declare the class. Note that we derive from InstWalker instantiated -// /// with _our new subclasses_ type. -// struct LoopCounter : public InstWalker<LoopCounter> { -// unsigned numLoops; -// LoopCounter() : numLoops(0) {} -// void visitForInst(ForInst &fs) { ++numLoops; } -// }; -// -// And this class would be used like this: -// LoopCounter lc; -// lc.walk(function); -// numLoops = lc.numLoops; -// -// There are 'visit' methods for Instruction and Function, which recursively -// process all contained instructions. -// -// Note that if you don't implement visitXXX for some instruction type, -// the visitXXX method for Instruction superclass will be invoked. -// -// The optional second template argument specifies the type that instruction -// visitation functions should return. If you specify this, you *MUST* provide -// an implementation of every visit<#Instruction>(InstType *). -// -// Note that these classes are specifically designed as a template to avoid -// virtual function call overhead. - -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_INSTVISITOR_H -#define MLIR_IR_INSTVISITOR_H - -#include "mlir/IR/Function.h" -#include "mlir/IR/Instruction.h" - -namespace mlir { -/// Base class for instruction walkers. A walker can traverse depth first in -/// pre-order or post order. The walk methods without a suffix do a pre-order -/// traversal while those that traverse in post order have a PostOrder suffix. -template <typename SubClass, typename RetTy = void> class InstWalker { - //===--------------------------------------------------------------------===// - // Interface code - This is the public interface of the InstWalker used to - // walk instructions. - -public: - // Generic walk method - allow walk to all instructions in a range. - template <class Iterator> void walk(Iterator Start, Iterator End) { - while (Start != End) { - walk(&(*Start++)); - } - } - template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) { - while (Start != End) { - walkPostOrder(&(*Start++)); - } - } - - // Define walkers for Function and all Function instruction kinds. - void walk(Function *f) { - for (auto &block : *f) - static_cast<SubClass *>(this)->walk(block.begin(), block.end()); - } - - void walkPostOrder(Function *f) { - for (auto it = f->rbegin(), e = f->rend(); it != e; ++it) - static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end()); - } - - // Function to walk a instruction. - RetTy walk(Instruction *s) { - static_assert(std::is_base_of<InstWalker, SubClass>::value, - "Must pass the derived type to this template!"); - - static_cast<SubClass *>(this)->visitInstruction(s); - for (auto &blockList : s->getBlockLists()) - for (auto &block : blockList) - static_cast<SubClass *>(this)->walk(block.begin(), block.end()); - } - - // Function to walk a instruction in post order DFS. - RetTy walkPostOrder(Instruction *s) { - static_assert(std::is_base_of<InstWalker, SubClass>::value, - "Must pass the derived type to this template!"); - for (auto &blockList : s->getBlockLists()) - for (auto &block : blockList) - static_cast<SubClass *>(this)->walkPostOrder(block.begin(), - block.end()); - static_cast<SubClass *>(this)->visitInstruction(s); - } - - //===--------------------------------------------------------------------===// - // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular instruction type. - // The default behavior is to generalize the instruction type to its subtype - // and try visiting the subtype. All of this should be inlined perfectly, - // because there are no virtual functions to get in the way. - - // When visiting a specific inst directly during a walk, these methods get - // called. These are typically O(1) complexity and shouldn't be recursively - // processing their descendants in some way. When using RetTy, all of these - // need to be overridden. - void visitInstruction(Instruction *inst) {} -}; - -} // end namespace mlir - -#endif // MLIR_IR_INSTVISITOR_H diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index c8a1dc8a7bb..bbd0ba10d65 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -614,6 +614,36 @@ public: } //===--------------------------------------------------------------------===// + // Instruction Walkers + //===--------------------------------------------------------------------===// + + /// Walk the instructions held by this instruction in preorder, calling the + /// callback for each instruction. + void walk(const std::function<void(Instruction *)> &callback); + + /// Specialization of walk to only visit operations of 'OpTy'. + template <typename OpTy> + void walk(std::function<void(OpPointer<OpTy>)> callback) { + walk([&](Instruction *inst) { + if (auto op = inst->dyn_cast<OpTy>()) + callback(op); + }); + } + + /// Walk the instructions held by this function in postorder, calling the + /// callback for each instruction. + void walkPostOrder(const std::function<void(Instruction *)> &callback); + + /// Specialization of walkPostOrder to only visit operations of 'OpTy'. + template <typename OpTy> + void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) { + walkPostOrder([&](Instruction *inst) { + if (auto op = inst->dyn_cast<OpTy>()) + callback(op); + }); + } + + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 39345d7fc7a..c3adf5fb7c3 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -646,32 +645,6 @@ bool AffineForOp::matchingBoundOperandList() const { return true; } -void AffineForOp::walk(std::function<void(Instruction *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(Instruction *)> const &callback; - Walker(std::function<void(Instruction *)> const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(getInstruction()); -} - -void AffineForOp::walkPostOrder(std::function<void(Instruction *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(Instruction *)> const &callback; - Walker(std::function<void(Instruction *)> const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(getInstruction()); -} - /// Returns the induction variable for this loop. Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); } diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index ab22f261a3b..3376cd7d512 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -26,7 +26,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,13 +37,11 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> { +struct MemRefBoundCheck : public FunctionPass { explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); - static char passID; }; @@ -56,17 +53,16 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitInstruction(Instruction *opInst) { - if (auto loadOp = opInst->dyn_cast<LoadOp>()) { - boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { - boundCheckLoadOrStoreOp(storeOp); - } - // TODO(bondhugula): do this for DMA ops as well. -} - PassResult MemRefBoundCheck::runOnFunction(Function *f) { - return walk(f), success(); + f->walk([](Instruction *opInst) { + if (auto loadOp = opInst->dyn_cast<LoadOp>()) { + boundCheckLoadOrStoreOp(loadOp); + } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + boundCheckLoadOrStoreOp(storeOp); + } + // TODO(bondhugula): do this for DMA ops as well. + }); + return success(); } static PassRegistration<MemRefBoundCheck> diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 6ea47a20f60..9ec1c95f213 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,19 +37,13 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. -struct MemRefDependenceCheck : public FunctionPass, - InstWalker<MemRefDependenceCheck> { +struct MemRefDependenceCheck : public FunctionPass { SmallVector<Instruction *, 4> loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst) { - if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) { - loadsAndStores.push_back(opInst); - } - } static char passID; }; @@ -120,8 +113,13 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. PassResult MemRefDependenceCheck::runOnFunction(Function *f) { + // Collect the loads and stores within the function. loadsAndStores.clear(); - walk(f); + f->walk([&](Instruction *inst) { + if (inst->isa<LoadOp>() || inst->isa<StoreOp>()) + loadsAndStores.push_back(inst); + }); + checkDependences(loadsAndStores); return success(); } diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 742c0baa96b..f05f8737b16 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -15,7 +15,6 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" @@ -27,16 +26,13 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> { +struct PrintOpStatsPass : public ModulePass { explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : ModulePass(&PrintOpStatsPass::passID), os(os) {} // Prints the resultant operation statistics post iterating over the module. PassResult runOnModule(Module *m) override; - // Updates the operation statistics for the given instruction. - void visitInstruction(Instruction *inst); - // Print summary of op stats. void printSummary(); @@ -44,7 +40,6 @@ struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> { private: llvm::StringMap<int64_t> opCount; - llvm::raw_ostream &os; }; } // namespace @@ -52,16 +47,16 @@ private: char PrintOpStatsPass::passID = 0; PassResult PrintOpStatsPass::runOnModule(Module *m) { + opCount.clear(); + + // Compute the operation statistics for each function in the module. for (auto &fn : *m) - walk(&fn); + fn.walk( + [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); printSummary(); return success(); } -void PrintOpStatsPass::visitInstruction(Instruction *inst) { - ++opCount[inst->getName().getStringRef()]; -} - void PrintOpStatsPass::printSummary() { os << "Operations encountered:\n"; os << "-----------------------\n"; diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 1703a16c2b8..cea99121dc9 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -19,7 +19,6 @@ #include "mlir/EDSC/Types.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a69920cbd86..ffc863d76d0 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index f18ce8e33a8..e6dfc4c2145 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -18,7 +18,6 @@ #include "mlir/IR/Block.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" using namespace mlir; @@ -227,6 +226,34 @@ Block *Block::getSinglePredecessor() { } //===----------------------------------------------------------------------===// +// Instruction Walkers +//===----------------------------------------------------------------------===// + +void Block::walk(const std::function<void(Instruction *)> &callback) { + walk(begin(), end(), callback); +} + +void Block::walk(Block::iterator begin, Block::iterator end, + const std::function<void(Instruction *)> &callback) { + // Walk the instructions within this block. + for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end))) + inst.walk(callback); +} + +void Block::walkPostOrder(const std::function<void(Instruction *)> &callback) { + walkPostOrder(begin(), end(), callback); +} + +/// Walk the instructions in the specified [begin, end) range of this block +/// in postorder, calling the callback for each operation. +void Block::walkPostOrder(Block::iterator begin, Block::iterator end, + const std::function<void(Instruction *)> &callback) { + // Walk the instructions within this block. + for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end))) + inst.walkPostOrder(callback); +} + +//===----------------------------------------------------------------------===// // Other //===----------------------------------------------------------------------===// @@ -253,37 +280,6 @@ Block *Block::splitBlock(iterator splitBefore) { return newBB; } -void Block::walk(std::function<void(Instruction *)> callback) { - walk(begin(), end(), callback); -} - -void Block::walk(Block::iterator begin, Block::iterator end, - std::function<void(Instruction *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(Instruction *)> const &callback; - Walker(std::function<void(Instruction *)> const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker w(callback); - w.walk(begin, end); -} - -void Block::walkPostOrder(std::function<void(Instruction *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(Instruction *)> const &callback; - Walker(std::function<void(Instruction *)> const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *opInst) { callback(opInst); } - }; - - Walker v(callback); - v.walkPostOrder(begin(), end()); -} - //===----------------------------------------------------------------------===// // BlockList //===----------------------------------------------------------------------===// @@ -331,25 +327,18 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper, // Now that each of the blocks have been cloned, go through and remap the // operands of each of the instructions. - struct Walker : public InstWalker<Walker> { - BlockAndValueMapping &mapper; - Walker(BlockAndValueMapping &mapper) : mapper(mapper) {} - - /// Remap the instruction and successor block operands. - void visitInstruction(Instruction *inst) { - for (auto &instOp : inst->getInstOperands()) - if (auto *mappedOp = mapper.lookupOrNull(instOp.get())) - instOp.set(mappedOp); - if (inst->isTerminator()) - for (auto &succOp : inst->getBlockOperands()) - if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) - succOp.set(mappedOp); - } + auto remapOperands = [&](Instruction *inst) { + for (auto &instOp : inst->getInstOperands()) + if (auto *mappedOp = mapper.lookupOrNull(instOp.get())) + instOp.set(mappedOp); + if (inst->isTerminator()) + for (auto &succOp : inst->getBlockOperands()) + if (auto *mappedOp = mapper.lookupOrNull(succOp.get())) + succOp.set(mappedOp); }; - Walker v(mapper); for (auto it = std::next(lastOldBlock), e = dest->end(); it != e; ++it) - v.walk(it->begin(), it->end()); + it->walk(remapOperands); } BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() { diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 3a263fb13f9..ba781500c4f 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -19,7 +19,6 @@ #include "AttributeListStorage.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" @@ -214,28 +213,15 @@ void Function::addEntryBlock() { entry->addArguments(type.getInputs()); } -void Function::walk(std::function<void(Instruction *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(Instruction *)> const &callback; - Walker(std::function<void(Instruction *)> const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *inst) { callback(inst); } - }; - - Walker v(callback); - v.walk(this); +void Function::walk(const std::function<void(Instruction *)> &callback) { + // Walk each of the blocks within the function. + for (auto &block : getBlocks()) + block.walk(callback); } -void Function::walkPostOrder(std::function<void(Instruction *)> callback) { - struct Walker : public InstWalker<Walker> { - std::function<void(Instruction *)> const &callback; - Walker(std::function<void(Instruction *)> const &callback) - : callback(callback) {} - - void visitInstruction(Instruction *inst) { callback(inst); } - }; - - Walker v(callback); - v.walkPostOrder(this); +void Function::walkPostOrder( + const std::function<void(Instruction *)> &callback) { + // Walk each of the blocks within the function. + for (auto &block : llvm::reverse(getBlocks())) + block.walkPostOrder(callback); } diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6720969ac0f..062f13a3282 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -22,7 +22,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/DenseMap.h" @@ -300,6 +299,35 @@ Function *Instruction::getFunction() const { return block ? block->getFunction() : nullptr; } +//===----------------------------------------------------------------------===// +// Instruction Walkers +//===----------------------------------------------------------------------===// + +void Instruction::walk(const std::function<void(Instruction *)> &callback) { + // Visit the current instruction. + callback(this); + + // Visit any internal instructions. + for (auto &blockList : getBlockLists()) + for (auto &block : blockList) + block.walk(callback); +} + +void Instruction::walkPostOrder( + const std::function<void(Instruction *)> &callback) { + // Visit any internal instructions. + for (auto &blockList : llvm::reverse(getBlockLists())) + for (auto &block : llvm::reverse(blockList)) + block.walkPostOrder(callback); + + // Visit the current instruction. + callback(this); +} + +//===----------------------------------------------------------------------===// +// Other +//===----------------------------------------------------------------------===// + /// Emit a note about this instruction, reporting up to any diagnostic /// handlers that may be listening. void Instruction::emitNote(const Twine &message) const { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index ae77d66b183..b7e4fb147cb 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -26,7 +26,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 63a676d7b52..de10fe8a461 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -24,7 +24,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 289b00d3b51..796477c64f2 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -27,7 +27,6 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -46,10 +45,9 @@ namespace { // result of any AffineApplyOp). After this composition, AffineApplyOps with no // remaining uses are erased. // TODO(andydavis) Remove this when Chris adds instruction combiner pass. -struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> { +struct ComposeAffineMaps : public FunctionPass { explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps; @@ -68,15 +66,11 @@ static bool affineApplyOp(const Instruction &inst) { return inst.isa<AffineApplyOp>(); } -void ComposeAffineMaps::visitInstruction(Instruction *opInst) { - if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) - affineApplyOps.push_back(afOp); -} - PassResult ComposeAffineMaps::runOnFunction(Function *f) { // If needed for future efficiency, reserve space based on a pre-walk. affineApplyOps.clear(); - walk(f); + f->walk<AffineApplyOp>( + [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); }); for (auto afOp : affineApplyOps) { SmallVector<Value *, 8> operands(afOp->getOperands()); FuncBuilder b(afOp->getInstruction()); @@ -87,7 +81,8 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) { // Erase dead affine apply ops. affineApplyOps.clear(); - walk(f); + f->walk<AffineApplyOp>( + [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); }); for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) { if ((*it)->use_empty()) { (*it)->erase(); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 54486cdb293..e41ac0ad329 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -18,7 +18,6 @@ #include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -27,7 +26,7 @@ using namespace mlir; namespace { /// Simple constant folding pass. -struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> { +struct ConstantFold : public FunctionPass { ConstantFold() : FunctionPass(&ConstantFold::passID) {} // All constants in the function post folding. @@ -35,9 +34,7 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> { // Operations that were folded and that need to be erased. std::vector<Instruction *> opInstsToErase; - bool foldOperation(Instruction *op, - SmallVectorImpl<Value *> &existingConstants); - void visitInstruction(Instruction *op); + void foldInstruction(Instruction *op); PassResult runOnFunction(Function *f) override; static char passID; @@ -49,7 +46,7 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::visitInstruction(Instruction *op) { +void ConstantFold::foldInstruction(Instruction *op) { // If this operation is an AffineForOp, then fold the bounds. if (auto forOp = op->dyn_cast<AffineForOp>()) { constantFoldBounds(forOp); @@ -111,7 +108,7 @@ PassResult ConstantFold::runOnFunction(Function *f) { existingConstants.clear(); opInstsToErase.clear(); - walk(f); + f->walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 7a002168528..77e5a6aa04f 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -28,7 +28,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -111,22 +110,23 @@ namespace { // LoopNestStateCollector walks loop nests and collects load and store // operations, and whether or not an IfInst was encountered in the loop nest. -class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> { -public: +struct LoopNestStateCollector { SmallVector<OpPointer<AffineForOp>, 4> forOps; SmallVector<Instruction *, 4> loadOpInsts; SmallVector<Instruction *, 4> storeOpInsts; bool hasNonForRegion = false; - void visitInstruction(Instruction *opInst) { - if (opInst->isa<AffineForOp>()) - forOps.push_back(opInst->cast<AffineForOp>()); - else if (opInst->getNumBlockLists() != 0) - hasNonForRegion = true; - else if (opInst->isa<LoadOp>()) - loadOpInsts.push_back(opInst); - else if (opInst->isa<StoreOp>()) - storeOpInsts.push_back(opInst); + void collect(Instruction *instToWalk) { + instToWalk->walk([&](Instruction *opInst) { + if (opInst->isa<AffineForOp>()) + forOps.push_back(opInst->cast<AffineForOp>()); + else if (opInst->getNumBlockLists() != 0) + hasNonForRegion = true; + else if (opInst->isa<LoadOp>()) + loadOpInsts.push_back(opInst); + else if (opInst->isa<StoreOp>()) + storeOpInsts.push_back(opInst); + }); } }; @@ -510,7 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; - collector.walk(&inst); + collector.collect(&inst); // Return false if a non 'for' region was found (not currently supported). if (collector.hasNonForRegion) return false; @@ -606,41 +606,39 @@ struct LoopNestStats { // LoopNestStatsCollector walks a single loop nest and gathers per-loop // trip count and operation count statistics and records them in 'stats'. -class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> { -public: +struct LoopNestStatsCollector { LoopNestStats *stats; bool hasLoopWithNonConstTripCount = false; LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitInstruction(Instruction *opInst) { - auto forOp = opInst->dyn_cast<AffineForOp>(); - if (!forOp) - return; - - auto *forInst = forOp->getInstruction(); - auto *parentInst = forOp->getInstruction()->getParentInst(); - if (parentInst != nullptr) { - assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp"); - // Add mapping to 'forOp' from its parent AffineForOp. - stats->loopMap[parentInst].push_back(forOp); - } + void collect(Instruction *inst) { + inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) { + auto *forInst = forOp->getInstruction(); + auto *parentInst = forOp->getInstruction()->getParentInst(); + if (parentInst != nullptr) { + assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp"); + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentInst].push_back(forOp); + } - // Record the number of op instructions in the body of 'forOp'. - unsigned count = 0; - stats->opCountMap[forInst] = 0; - for (auto &inst : *forOp->getBody()) { - if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>())) - ++count; - } - stats->opCountMap[forInst] = count; - // Record trip count for 'forOp'. Set flag if trip count is not constant. - Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); - if (!maybeConstTripCount.hasValue()) { - hasLoopWithNonConstTripCount = true; - return; - } - stats->tripCountMap[forInst] = maybeConstTripCount.getValue(); + // Record the number of op instructions in the body of 'forOp'. + unsigned count = 0; + stats->opCountMap[forInst] = 0; + for (auto &inst : *forOp->getBody()) { + if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>())) + ++count; + } + stats->opCountMap[forInst] = count; + // Record trip count for 'forOp'. Set flag if trip count is not + // constant. + Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); + if (!maybeConstTripCount.hasValue()) { + hasLoopWithNonConstTripCount = true; + return; + } + stats->tripCountMap[forInst] = maybeConstTripCount.getValue(); + }); } }; @@ -1078,7 +1076,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats); - srcStatsCollector.walk(srcLoopIVs[0]->getInstruction()); + srcStatsCollector.collect(srcLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (srcStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1089,7 +1087,7 @@ static bool isFusionProfitable(Instruction *srcOpInst, LoopNestStats dstLoopNestStats; LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats); - dstStatsCollector.walk(dstLoopIVs[0]->getInstruction()); + dstStatsCollector.collect(dstLoopIVs[0]->getInstruction()); // Currently only constant trip count loop nests are supported. if (dstStatsCollector.hasLoopWithNonConstTripCount) return false; @@ -1474,7 +1472,7 @@ public: // Collect slice loop stats. LoopNestStateCollector sliceCollector; - sliceCollector.walk(sliceLoopNest->getInstruction()); + sliceCollector.collect(sliceLoopNest->getInstruction()); // Promote single iteration slice loops to single IV value. for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); @@ -1498,7 +1496,7 @@ public: // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.walk(dstAffineForOp->getInstruction()); + dstLoopCollector.collect(dstAffineForOp->getInstruction()); // Add new load ops to current Node load op list 'loads' to // continue fusing based on new operands. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index b1e15ccb07b..3a7cfb85e08 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -27,7 +27,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -95,15 +94,16 @@ char LoopUnroll::passID = 0; PassResult LoopUnroll::runOnFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> { - public: + struct InnermostLoopGatherer { // Store innermost loops as we walk. std::vector<OpPointer<AffineForOp>> loops; - // This method specialized to encode custom return logic. - using InstListType = llvm::iplist<Instruction>; - bool walkPostOrder(InstListType::iterator Start, - InstListType::iterator End) { + void walkPostOrder(Function *f) { + for (auto &b : *f) + walkPostOrder(b.begin(), b.end()); + } + + bool walkPostOrder(Block::iterator Start, Block::iterator End) { bool hasInnerLoops = false; // We need to walk all elements since all innermost loops need to be // gathered as opposed to determining whether this list has any inner @@ -112,7 +112,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) { hasInnerLoops |= walkPostOrder(&(*Start++)); return hasInnerLoops; } - bool walkPostOrder(Instruction *opInst) { bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) @@ -125,39 +124,21 @@ PassResult LoopUnroll::runOnFunction(Function *f) { } return hasInnerLoops; } - - // FIXME: can't use base class method for this because that in turn would - // need to use the derived class method above. CRTP doesn't allow it, and - // the compiler error resulting from it is also misleading. - using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder; }; - // Gathers all loops with trip count <= minTripCount. - class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> { - public: + if (clUnrollFull.getNumOccurrences() > 0 && + clUnrollFullThreshold.getNumOccurrences() > 0) { // Store short loops as we walk. std::vector<OpPointer<AffineForOp>> loops; - const unsigned minTripCount; - ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitInstruction(Instruction *opInst) { - auto forOp = opInst->dyn_cast<AffineForOp>(); - if (!forOp) - return; + // Gathers all loops with trip count <= minTripCount. Do a post order walk + // so that loops are gathered from innermost to outermost (or else unrolling + // an outer one may delete gathered inner ones). + f->walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) { Optional<uint64_t> tripCount = getConstantTripCount(forOp); - if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) + if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) loops.push_back(forOp); - } - }; - - if (clUnrollFull.getNumOccurrences() > 0 && - clUnrollFullThreshold.getNumOccurrences() > 0) { - ShortLoopGatherer slg(clUnrollFullThreshold); - // Do a post order walk so that loops are gathered from innermost to - // outermost (or else unrolling an outer one may delete gathered inner - // ones). - slg.walkPostOrder(f); - auto &loops = slg.loops; + }); for (auto forOp : loops) loopUnrollFull(forOp); return success(); diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 74c54fde047..b2aed7d9d7f 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -50,7 +50,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -136,24 +135,25 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp, // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though // in its tree). - class JamBlockGatherer : public InstWalker<JamBlockGatherer> { - public: - using InstListType = llvm::iplist<Instruction>; - using InstWalker<JamBlockGatherer>::walk; - + struct JamBlockGatherer { // Store iterators to the first and last inst of each sub-block found. std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; // This is a linear time walk. - void walk(InstListType::iterator Start, InstListType::iterator End) { - for (auto it = Start; it != End;) { + void walk(Instruction *inst) { + for (auto &blockList : inst->getBlockLists()) + for (auto &block : blockList) + walk(block); + } + void walk(Block &block) { + for (auto it = block.begin(), e = block.end(); it != e;) { auto subBlockStart = it; - while (it != End && !it->isa<AffineForOp>()) + while (it != e && !it->isa<AffineForOp>()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && it->isa<AffineForOp>()) + while (it != e && it->isa<AffineForOp>()) walk(&*it++); } } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 2d06a327315..9c9db30d163 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Passes.h" @@ -70,12 +69,12 @@ namespace { // currently only eliminates the stores only if no other loads/uses (other // than dealloc) remain. // -struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> { +struct MemRefDataFlowOpt : public FunctionPass { explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); + void forwardStoreToLoad(OpPointer<LoadOp> loadOp); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet<Value *, 4> memrefsToErase; @@ -100,14 +99,9 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. -void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) { +void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer<LoadOp> loadOp) { Instruction *lastWriteStoreOp = nullptr; - - auto loadOp = opInst->dyn_cast<LoadOp>(); - if (!loadOp) - return; - - Instruction *loadOpInst = opInst; + Instruction *loadOpInst = loadOp->getInstruction(); // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across @@ -235,7 +229,8 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - walk(f); + f->walk<LoadOp>( + [&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index ba3be5e95f4..4ca48a53485 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -142,10 +142,8 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkPostOrder([&](Instruction *opInst) { - if (auto forOp = opInst->dyn_cast<AffineForOp>()) - forOps.push_back(forOp); - }); + f->walkPostOrder<AffineForOp>( + [&](OpPointer<AffineForOp> forOp) { forOps.push_back(forOp); }); bool ret = false; for (auto forOp : forOps) { ret = ret | runOnAffineForOp(forOp); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 5bf17989bef..95875adca6e 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -28,7 +28,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" @@ -135,10 +134,8 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - f->walkPostOrder([](Instruction *inst) { - if (auto forOp = inst->dyn_cast<AffineForOp>()) - promoteIfSingleIteration(forOp); - }); + f->walkPostOrder<AffineForOp>( + [](OpPointer<AffineForOp> forOp) { promoteIfSingleIteration(forOp); }); } /// Generates a 'for' inst with the specified lower and upper bounds while |

