summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-02-04 16:24:44 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:12:59 -0700
commitbf9c381d1dbf4381659597109422e543d62a49d7 (patch)
tree493aaa02d23039a9fcd31b9ff4b3a4f0af91df3a /mlir
parentc9ad4621ce2d68cad547da360aedeee733b73f32 (diff)
downloadbcm5719-llvm-bf9c381d1dbf4381659597109422e543d62a49d7.tar.gz
bcm5719-llvm-bf9c381d1dbf4381659597109422e543d62a49d7.zip
Remove InstWalker and move all instruction walking to the api facilities on Function/Block/Instruction.
PiperOrigin-RevId: 232388113
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/AffineOps/AffineOps.h8
-rw-r--r--mlir/include/mlir/Analysis/NestedMatcher.h37
-rw-r--r--mlir/include/mlir/IR/Block.h35
-rw-r--r--mlir/include/mlir/IR/Function.h32
-rw-r--r--mlir/include/mlir/IR/InstVisitor.h140
-rw-r--r--mlir/include/mlir/IR/Instruction.h30
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp27
-rw-r--r--mlir/lib/Analysis/MemRefBoundCheck.cpp24
-rw-r--r--mlir/lib/Analysis/MemRefDependenceCheck.cpp16
-rw-r--r--mlir/lib/Analysis/OpStats.cpp17
-rw-r--r--mlir/lib/EDSC/LowerEDSCTestPass.cpp1
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp1
-rw-r--r--mlir/lib/IR/Block.cpp85
-rw-r--r--mlir/lib/IR/Function.cpp32
-rw-r--r--mlir/lib/IR/Instruction.cpp30
-rw-r--r--mlir/lib/Parser/Parser.cpp1
-rw-r--r--mlir/lib/Transforms/CSE.cpp1
-rw-r--r--mlir/lib/Transforms/ComposeAffineMaps.cpp15
-rw-r--r--mlir/lib/Transforms/ConstantFold.cpp11
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp90
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp49
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp20
-rw-r--r--mlir/lib/Transforms/MemRefDataFlowOpt.cpp17
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp6
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp7
25 files changed, 277 insertions, 455 deletions
diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h
index 43ad0354b01..5ec536f47e7 100644
--- a/mlir/include/mlir/AffineOps/AffineOps.h
+++ b/mlir/include/mlir/AffineOps/AffineOps.h
@@ -229,14 +229,6 @@ public:
/// (same operands in the same order).
bool matchingBoundOperandList() const;
- /// Walk the operation instructions in the 'for' instruction in preorder,
- /// calling the callback for each operation.
- void walk(std::function<void(Instruction *)> callback);
-
- /// Walk the operation instructions in the 'for' instruction in postorder,
- /// calling the callback for each operation.
- void walkPostOrder(std::function<void(Instruction *)> callback);
-
private:
friend class Instruction;
explicit AffineForOp(const Instruction *state) : Op(state) {}
diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h
index aba0e11ab91..44fe4c0558a 100644
--- a/mlir/include/mlir/Analysis/NestedMatcher.h
+++ b/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -18,7 +18,7 @@
#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
-#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Function.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
@@ -76,7 +76,7 @@ private:
ArrayRef<NestedMatch> matchedChildren;
};
-/// A NestedPattern is a nested InstWalker that:
+/// A NestedPattern is a nested instruction walker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
@@ -92,8 +92,8 @@ private:
///
/// The NestedMatches captured in the IR can grow large, especially after
/// aggressive unrolling. As experience has shown, it is generally better to use
-/// a plain InstWalker to match flat patterns but the current implementation is
-/// competitive nonetheless.
+/// a plain walk over instructions to match flat patterns but the current
+/// implementation is competitive nonetheless.
using FilterFunctionType = std::function<bool(const Instruction &)>;
static bool defaultFilterFunction(const Instruction &) { return true; };
struct NestedPattern {
@@ -102,16 +102,14 @@ struct NestedPattern {
NestedPattern(const NestedPattern &) = default;
NestedPattern &operator=(const NestedPattern &) = default;
- /// Returns all the top-level matches in `function`.
- void match(Function *function, SmallVectorImpl<NestedMatch> *matches) {
- State state(*this, matches);
- state.walkPostOrder(function);
+ /// Returns all the top-level matches in `func`.
+ void match(Function *func, SmallVectorImpl<NestedMatch> *matches) {
+ func->walkPostOrder([&](Instruction *inst) { matchOne(inst, matches); });
}
/// Returns all the top-level matches in `inst`.
void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) {
- State state(*this, matches);
- state.walkPostOrder(inst);
+ inst->walkPostOrder([&](Instruction *child) { matchOne(child, matches); });
}
/// Returns the depth of the pattern.
@@ -120,22 +118,8 @@ struct NestedPattern {
private:
friend class NestedPatternContext;
friend class NestedMatch;
- friend class InstWalker<NestedPattern>;
friend struct State;
- /// Helper state that temporarily holds matches for the next level of nesting.
- struct State : public InstWalker<State> {
- State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
- : pattern(pattern), matches(matches) {}
- void visitInstruction(Instruction *opInst) {
- pattern.matchOne(opInst, matches);
- }
-
- private:
- NestedPattern &pattern;
- SmallVectorImpl<NestedMatch> *matches;
- };
-
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
@@ -153,8 +137,9 @@ private:
/// without switching on the type of the Instruction. The idea is that a
/// NestedPattern first checks if it matches locally and then recursively
/// applies its nested matchers to its elem->nested. Since we want to rely on
- /// the InstWalker impl rather than duplicate its the logic, we allow an
- /// off-by-one traversal to account for the fact that we write:
+ /// the existing instruction walking functionality rather than duplicate
+ /// it, we allow an off-by-one traversal to account for the fact that we
+ /// write:
///
/// void match(Instruction *elem) {
/// for (auto &c : getNestedPatterns()) {
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index f3a2218d0f9..6e44770282b 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -288,6 +288,28 @@ public:
llvm::iterator_range<succ_iterator> getSuccessors();
//===--------------------------------------------------------------------===//
+ // Instruction Walkers
+ //===--------------------------------------------------------------------===//
+
+ /// Walk the instructions of this block in preorder, calling the callback for
+ /// each operation.
+ void walk(const std::function<void(Instruction *)> &callback);
+
+ /// Walk the instructions in the specified [begin, end) range of
+ /// this block, calling the callback for each operation.
+ void walk(Block::iterator begin, Block::iterator end,
+ const std::function<void(Instruction *)> &callback);
+
+ /// Walk the instructions in this block in postorder, calling the callback for
+ /// each operation.
+ void walkPostOrder(const std::function<void(Instruction *)> &callback);
+
+ /// Walk the instructions in the specified [begin, end) range of this block
+ /// in postorder, calling the callback for each operation.
+ void walkPostOrder(Block::iterator begin, Block::iterator end,
+ const std::function<void(Instruction *)> &callback);
+
+ //===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
@@ -311,19 +333,6 @@ public:
return &Block::instructions;
}
- /// Walk the instructions of this block in preorder, calling the callback for
- /// each operation.
- void walk(std::function<void(Instruction *)> callback);
-
- /// Walk the instructions in this block in postorder, calling the callback for
- /// each operation.
- void walkPostOrder(std::function<void(Instruction *)> callback);
-
- /// Walk the instructions in the specified [begin, end) range of
- /// this block, calling the callback for each operation.
- void walk(Block::iterator begin, Block::iterator end,
- std::function<void(Instruction *)> callback);
-
void print(raw_ostream &os) const;
void dump() const;
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index f483ff46259..3afb021c8ec 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -27,6 +27,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Instruction.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
@@ -39,6 +40,7 @@ class FunctionType;
class MLIRContext;
class Module;
template <typename ObjectType, typename ElementType> class ArgumentIterator;
+template <typename T> class OpPointer;
/// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
@@ -115,13 +117,35 @@ public:
Block &front() { return blocks.front(); }
const Block &front() const { return const_cast<Function *>(this)->front(); }
+ //===--------------------------------------------------------------------===//
+ // Instruction Walkers
+ //===--------------------------------------------------------------------===//
+
/// Walk the instructions in the function in preorder, calling the callback
- /// for each instruction or operation.
- void walk(std::function<void(Instruction *)> callback);
+ /// for each instruction.
+ void walk(const std::function<void(Instruction *)> &callback);
+
+ /// Specialization of walk to only visit operations of 'OpTy'.
+ template <typename OpTy>
+ void walk(std::function<void(OpPointer<OpTy>)> callback) {
+ walk([&](Instruction *inst) {
+ if (auto op = inst->dyn_cast<OpTy>())
+ callback(op);
+ });
+ }
/// Walk the instructions in the function in postorder, calling the callback
- /// for each instruction or operation.
- void walkPostOrder(std::function<void(Instruction *)> callback);
+ /// for each instruction.
+ void walkPostOrder(const std::function<void(Instruction *)> &callback);
+
+ /// Specialization of walkPostOrder to only visit operations of 'OpTy'.
+ template <typename OpTy>
+ void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) {
+ walkPostOrder([&](Instruction *inst) {
+ if (auto op = inst->dyn_cast<OpTy>())
+ callback(op);
+ });
+ }
//===--------------------------------------------------------------------===//
// Arguments
diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h
deleted file mode 100644
index e11b7350894..00000000000
--- a/mlir/include/mlir/IR/InstVisitor.h
+++ /dev/null
@@ -1,140 +0,0 @@
-//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines the base classes for Function's instruction visitors and
-// walkers. A visit is a O(1) operation that visits just the node in question. A
-// walk visits the node it's called on as well as the node's descendants.
-//
-// Instruction visitors/walkers are used when you want to perform different
-// actions for different kinds of instructions without having to use lots of
-// casts and a big switch instruction.
-//
-// To define your own visitor/walker, inherit from these classes, specifying
-// your new type for the 'SubClass' template parameter, and "override" visitXXX
-// functions in your class. This class is defined in terms of statically
-// resolved overloading, not virtual functions.
-//
-// For example, here is a walker that counts the number of for loops in an
-// Function.
-//
-// /// Declare the class. Note that we derive from InstWalker instantiated
-// /// with _our new subclasses_ type.
-// struct LoopCounter : public InstWalker<LoopCounter> {
-// unsigned numLoops;
-// LoopCounter() : numLoops(0) {}
-// void visitForInst(ForInst &fs) { ++numLoops; }
-// };
-//
-// And this class would be used like this:
-// LoopCounter lc;
-// lc.walk(function);
-// numLoops = lc.numLoops;
-//
-// There are 'visit' methods for Instruction and Function, which recursively
-// process all contained instructions.
-//
-// Note that if you don't implement visitXXX for some instruction type,
-// the visitXXX method for Instruction superclass will be invoked.
-//
-// The optional second template argument specifies the type that instruction
-// visitation functions should return. If you specify this, you *MUST* provide
-// an implementation of every visit<#Instruction>(InstType *).
-//
-// Note that these classes are specifically designed as a template to avoid
-// virtual function call overhead.
-
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_INSTVISITOR_H
-#define MLIR_IR_INSTVISITOR_H
-
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Instruction.h"
-
-namespace mlir {
-/// Base class for instruction walkers. A walker can traverse depth first in
-/// pre-order or post order. The walk methods without a suffix do a pre-order
-/// traversal while those that traverse in post order have a PostOrder suffix.
-template <typename SubClass, typename RetTy = void> class InstWalker {
- //===--------------------------------------------------------------------===//
- // Interface code - This is the public interface of the InstWalker used to
- // walk instructions.
-
-public:
- // Generic walk method - allow walk to all instructions in a range.
- template <class Iterator> void walk(Iterator Start, Iterator End) {
- while (Start != End) {
- walk(&(*Start++));
- }
- }
- template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
- while (Start != End) {
- walkPostOrder(&(*Start++));
- }
- }
-
- // Define walkers for Function and all Function instruction kinds.
- void walk(Function *f) {
- for (auto &block : *f)
- static_cast<SubClass *>(this)->walk(block.begin(), block.end());
- }
-
- void walkPostOrder(Function *f) {
- for (auto it = f->rbegin(), e = f->rend(); it != e; ++it)
- static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end());
- }
-
- // Function to walk a instruction.
- RetTy walk(Instruction *s) {
- static_assert(std::is_base_of<InstWalker, SubClass>::value,
- "Must pass the derived type to this template!");
-
- static_cast<SubClass *>(this)->visitInstruction(s);
- for (auto &blockList : s->getBlockLists())
- for (auto &block : blockList)
- static_cast<SubClass *>(this)->walk(block.begin(), block.end());
- }
-
- // Function to walk a instruction in post order DFS.
- RetTy walkPostOrder(Instruction *s) {
- static_assert(std::is_base_of<InstWalker, SubClass>::value,
- "Must pass the derived type to this template!");
- for (auto &blockList : s->getBlockLists())
- for (auto &block : blockList)
- static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
- block.end());
- static_cast<SubClass *>(this)->visitInstruction(s);
- }
-
- //===--------------------------------------------------------------------===//
- // Visitation functions... these functions provide default fallbacks in case
- // the user does not specify what to do for a particular instruction type.
- // The default behavior is to generalize the instruction type to its subtype
- // and try visiting the subtype. All of this should be inlined perfectly,
- // because there are no virtual functions to get in the way.
-
- // When visiting a specific inst directly during a walk, these methods get
- // called. These are typically O(1) complexity and shouldn't be recursively
- // processing their descendants in some way. When using RetTy, all of these
- // need to be overridden.
- void visitInstruction(Instruction *inst) {}
-};
-
-} // end namespace mlir
-
-#endif // MLIR_IR_INSTVISITOR_H
diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h
index c8a1dc8a7bb..bbd0ba10d65 100644
--- a/mlir/include/mlir/IR/Instruction.h
+++ b/mlir/include/mlir/IR/Instruction.h
@@ -614,6 +614,36 @@ public:
}
//===--------------------------------------------------------------------===//
+ // Instruction Walkers
+ //===--------------------------------------------------------------------===//
+
+ /// Walk the instructions held by this instruction in preorder, calling the
+ /// callback for each instruction.
+ void walk(const std::function<void(Instruction *)> &callback);
+
+ /// Specialization of walk to only visit operations of 'OpTy'.
+ template <typename OpTy>
+ void walk(std::function<void(OpPointer<OpTy>)> callback) {
+ walk([&](Instruction *inst) {
+ if (auto op = inst->dyn_cast<OpTy>())
+ callback(op);
+ });
+ }
+
+ /// Walk the instructions held by this function in postorder, calling the
+ /// callback for each instruction.
+ void walkPostOrder(const std::function<void(Instruction *)> &callback);
+
+ /// Specialization of walkPostOrder to only visit operations of 'OpTy'.
+ template <typename OpTy>
+ void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) {
+ walkPostOrder([&](Instruction *inst) {
+ if (auto op = inst->dyn_cast<OpTy>())
+ callback(op);
+ });
+ }
+
+ //===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 39345d7fc7a..c3adf5fb7c3 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -19,7 +19,6 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -646,32 +645,6 @@ bool AffineForOp::matchingBoundOperandList() const {
return true;
}
-void AffineForOp::walk(std::function<void(Instruction *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(Instruction *)> const &callback;
- Walker(std::function<void(Instruction *)> const &callback)
- : callback(callback) {}
-
- void visitInstruction(Instruction *opInst) { callback(opInst); }
- };
-
- Walker w(callback);
- w.walk(getInstruction());
-}
-
-void AffineForOp::walkPostOrder(std::function<void(Instruction *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(Instruction *)> const &callback;
- Walker(std::function<void(Instruction *)> const &callback)
- : callback(callback) {}
-
- void visitInstruction(Instruction *opInst) { callback(opInst); }
- };
-
- Walker v(callback);
- v.walkPostOrder(getInstruction());
-}
-
/// Returns the induction variable for this loop.
Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); }
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index ab22f261a3b..3376cd7d512 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -26,7 +26,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@@ -38,13 +37,11 @@ using namespace mlir;
namespace {
/// Checks for out of bound memef access subscripts..
-struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
+struct MemRefBoundCheck : public FunctionPass {
explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
PassResult runOnFunction(Function *f) override;
- void visitInstruction(Instruction *opInst);
-
static char passID;
};
@@ -56,17 +53,16 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
-void MemRefBoundCheck::visitInstruction(Instruction *opInst) {
- if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
- boundCheckLoadOrStoreOp(loadOp);
- } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
- boundCheckLoadOrStoreOp(storeOp);
- }
- // TODO(bondhugula): do this for DMA ops as well.
-}
-
PassResult MemRefBoundCheck::runOnFunction(Function *f) {
- return walk(f), success();
+ f->walk([](Instruction *opInst) {
+ if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
+ boundCheckLoadOrStoreOp(loadOp);
+ } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
+ boundCheckLoadOrStoreOp(storeOp);
+ }
+ // TODO(bondhugula): do this for DMA ops as well.
+ });
+ return success();
}
static PassRegistration<MemRefBoundCheck>
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index 6ea47a20f60..9ec1c95f213 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -25,7 +25,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@@ -38,19 +37,13 @@ namespace {
// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
/// Checks dependences between all pairs of memref accesses in a Function.
-struct MemRefDependenceCheck : public FunctionPass,
- InstWalker<MemRefDependenceCheck> {
+struct MemRefDependenceCheck : public FunctionPass {
SmallVector<Instruction *, 4> loadsAndStores;
explicit MemRefDependenceCheck()
: FunctionPass(&MemRefDependenceCheck::passID) {}
PassResult runOnFunction(Function *f) override;
- void visitInstruction(Instruction *opInst) {
- if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
- loadsAndStores.push_back(opInst);
- }
- }
static char passID;
};
@@ -120,8 +113,13 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) {
// Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
// Runs pair-wise dependence checks.
PassResult MemRefDependenceCheck::runOnFunction(Function *f) {
+ // Collect the loads and stores within the function.
loadsAndStores.clear();
- walk(f);
+ f->walk([&](Instruction *inst) {
+ if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
+ loadsAndStores.push_back(inst);
+ });
+
checkDependences(loadsAndStores);
return success();
}
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index 742c0baa96b..f05f8737b16 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -15,7 +15,6 @@
// limitations under the License.
// =============================================================================
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
@@ -27,16 +26,13 @@
using namespace mlir;
namespace {
-struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> {
+struct PrintOpStatsPass : public ModulePass {
explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
: ModulePass(&PrintOpStatsPass::passID), os(os) {}
// Prints the resultant operation statistics post iterating over the module.
PassResult runOnModule(Module *m) override;
- // Updates the operation statistics for the given instruction.
- void visitInstruction(Instruction *inst);
-
// Print summary of op stats.
void printSummary();
@@ -44,7 +40,6 @@ struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> {
private:
llvm::StringMap<int64_t> opCount;
-
llvm::raw_ostream &os;
};
} // namespace
@@ -52,16 +47,16 @@ private:
char PrintOpStatsPass::passID = 0;
PassResult PrintOpStatsPass::runOnModule(Module *m) {
+ opCount.clear();
+
+ // Compute the operation statistics for each function in the module.
for (auto &fn : *m)
- walk(&fn);
+ fn.walk(
+ [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; });
printSummary();
return success();
}
-void PrintOpStatsPass::visitInstruction(Instruction *inst) {
- ++opCount[inst->getName().getStringRef()];
-}
-
void PrintOpStatsPass::printSummary() {
os << "Operations encountered:\n";
os << "-----------------------\n";
diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp
index 1703a16c2b8..cea99121dc9 100644
--- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp
+++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp
@@ -19,7 +19,6 @@
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a69920cbd86..ffc863d76d0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -25,7 +25,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index f18ce8e33a8..e6dfc4c2145 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -18,7 +18,6 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
using namespace mlir;
@@ -227,6 +226,34 @@ Block *Block::getSinglePredecessor() {
}
//===----------------------------------------------------------------------===//
+// Instruction Walkers
+//===----------------------------------------------------------------------===//
+
+void Block::walk(const std::function<void(Instruction *)> &callback) {
+ walk(begin(), end(), callback);
+}
+
+void Block::walk(Block::iterator begin, Block::iterator end,
+ const std::function<void(Instruction *)> &callback) {
+ // Walk the instructions within this block.
+ for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end)))
+ inst.walk(callback);
+}
+
+void Block::walkPostOrder(const std::function<void(Instruction *)> &callback) {
+ walkPostOrder(begin(), end(), callback);
+}
+
+/// Walk the instructions in the specified [begin, end) range of this block
+/// in postorder, calling the callback for each operation.
+void Block::walkPostOrder(Block::iterator begin, Block::iterator end,
+ const std::function<void(Instruction *)> &callback) {
+ // Walk the instructions within this block.
+ for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end)))
+ inst.walkPostOrder(callback);
+}
+
+//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//
@@ -253,37 +280,6 @@ Block *Block::splitBlock(iterator splitBefore) {
return newBB;
}
-void Block::walk(std::function<void(Instruction *)> callback) {
- walk(begin(), end(), callback);
-}
-
-void Block::walk(Block::iterator begin, Block::iterator end,
- std::function<void(Instruction *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(Instruction *)> const &callback;
- Walker(std::function<void(Instruction *)> const &callback)
- : callback(callback) {}
-
- void visitInstruction(Instruction *opInst) { callback(opInst); }
- };
-
- Walker w(callback);
- w.walk(begin, end);
-}
-
-void Block::walkPostOrder(std::function<void(Instruction *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(Instruction *)> const &callback;
- Walker(std::function<void(Instruction *)> const &callback)
- : callback(callback) {}
-
- void visitInstruction(Instruction *opInst) { callback(opInst); }
- };
-
- Walker v(callback);
- v.walkPostOrder(begin(), end());
-}
-
//===----------------------------------------------------------------------===//
// BlockList
//===----------------------------------------------------------------------===//
@@ -331,25 +327,18 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper,
// Now that each of the blocks have been cloned, go through and remap the
// operands of each of the instructions.
- struct Walker : public InstWalker<Walker> {
- BlockAndValueMapping &mapper;
- Walker(BlockAndValueMapping &mapper) : mapper(mapper) {}
-
- /// Remap the instruction and successor block operands.
- void visitInstruction(Instruction *inst) {
- for (auto &instOp : inst->getInstOperands())
- if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
- instOp.set(mappedOp);
- if (inst->isTerminator())
- for (auto &succOp : inst->getBlockOperands())
- if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
- succOp.set(mappedOp);
- }
+ auto remapOperands = [&](Instruction *inst) {
+ for (auto &instOp : inst->getInstOperands())
+ if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
+ instOp.set(mappedOp);
+ if (inst->isTerminator())
+ for (auto &succOp : inst->getBlockOperands())
+ if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
+ succOp.set(mappedOp);
};
- Walker v(mapper);
for (auto it = std::next(lastOldBlock), e = dest->end(); it != e; ++it)
- v.walk(it->begin(), it->end());
+ it->walk(remapOperands);
}
BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() {
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 3a263fb13f9..ba781500c4f 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -19,7 +19,6 @@
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
@@ -214,28 +213,15 @@ void Function::addEntryBlock() {
entry->addArguments(type.getInputs());
}
-void Function::walk(std::function<void(Instruction *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(Instruction *)> const &callback;
- Walker(std::function<void(Instruction *)> const &callback)
- : callback(callback) {}
-
- void visitInstruction(Instruction *inst) { callback(inst); }
- };
-
- Walker v(callback);
- v.walk(this);
+void Function::walk(const std::function<void(Instruction *)> &callback) {
+ // Walk each of the blocks within the function.
+ for (auto &block : getBlocks())
+ block.walk(callback);
}
-void Function::walkPostOrder(std::function<void(Instruction *)> callback) {
- struct Walker : public InstWalker<Walker> {
- std::function<void(Instruction *)> const &callback;
- Walker(std::function<void(Instruction *)> const &callback)
- : callback(callback) {}
-
- void visitInstruction(Instruction *inst) { callback(inst); }
- };
-
- Walker v(callback);
- v.walkPostOrder(this);
+void Function::walkPostOrder(
+ const std::function<void(Instruction *)> &callback) {
+ // Walk each of the blocks within the function.
+ for (auto &block : llvm::reverse(getBlocks()))
+ block.walkPostOrder(callback);
}
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp
index 6720969ac0f..062f13a3282 100644
--- a/mlir/lib/IR/Instruction.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -22,7 +22,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/DenseMap.h"
@@ -300,6 +299,35 @@ Function *Instruction::getFunction() const {
return block ? block->getFunction() : nullptr;
}
+//===----------------------------------------------------------------------===//
+// Instruction Walkers
+//===----------------------------------------------------------------------===//
+
+void Instruction::walk(const std::function<void(Instruction *)> &callback) {
+ // Visit the current instruction.
+ callback(this);
+
+ // Visit any internal instructions.
+ for (auto &blockList : getBlockLists())
+ for (auto &block : blockList)
+ block.walk(callback);
+}
+
+void Instruction::walkPostOrder(
+ const std::function<void(Instruction *)> &callback) {
+ // Visit any internal instructions.
+ for (auto &blockList : llvm::reverse(getBlockLists()))
+ for (auto &block : llvm::reverse(blockList))
+ block.walkPostOrder(callback);
+
+ // Visit the current instruction.
+ callback(this);
+}
+
+//===----------------------------------------------------------------------===//
+// Other
+//===----------------------------------------------------------------------===//
+
/// Emit a note about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
void Instruction::emitNote(const Twine &message) const {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index ae77d66b183..b7e4fb147cb 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -26,7 +26,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 63a676d7b52..de10fe8a461 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -24,7 +24,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp
index 289b00d3b51..796477c64f2 100644
--- a/mlir/lib/Transforms/ComposeAffineMaps.cpp
+++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp
@@ -27,7 +27,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@@ -46,10 +45,9 @@ namespace {
// result of any AffineApplyOp). After this composition, AffineApplyOps with no
// remaining uses are erased.
// TODO(andydavis) Remove this when Chris adds instruction combiner pass.
-struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
+struct ComposeAffineMaps : public FunctionPass {
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
PassResult runOnFunction(Function *f) override;
- void visitInstruction(Instruction *opInst);
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
@@ -68,15 +66,11 @@ static bool affineApplyOp(const Instruction &inst) {
return inst.isa<AffineApplyOp>();
}
-void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
- if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
- affineApplyOps.push_back(afOp);
-}
-
PassResult ComposeAffineMaps::runOnFunction(Function *f) {
// If needed for future efficiency, reserve space based on a pre-walk.
affineApplyOps.clear();
- walk(f);
+ f->walk<AffineApplyOp>(
+ [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
for (auto afOp : affineApplyOps) {
SmallVector<Value *, 8> operands(afOp->getOperands());
FuncBuilder b(afOp->getInstruction());
@@ -87,7 +81,8 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) {
// Erase dead affine apply ops.
affineApplyOps.clear();
- walk(f);
+ f->walk<AffineApplyOp>(
+ [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) {
if ((*it)->use_empty()) {
(*it)->erase();
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
index 54486cdb293..e41ac0ad329 100644
--- a/mlir/lib/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Transforms/ConstantFold.cpp
@@ -18,7 +18,6 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@@ -27,7 +26,7 @@ using namespace mlir;
namespace {
/// Simple constant folding pass.
-struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
+struct ConstantFold : public FunctionPass {
ConstantFold() : FunctionPass(&ConstantFold::passID) {}
// All constants in the function post folding.
@@ -35,9 +34,7 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
// Operations that were folded and that need to be erased.
std::vector<Instruction *> opInstsToErase;
- bool foldOperation(Instruction *op,
- SmallVectorImpl<Value *> &existingConstants);
- void visitInstruction(Instruction *op);
+ void foldInstruction(Instruction *op);
PassResult runOnFunction(Function *f) override;
static char passID;
@@ -49,7 +46,7 @@ char ConstantFold::passID = 0;
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
-void ConstantFold::visitInstruction(Instruction *op) {
+void ConstantFold::foldInstruction(Instruction *op) {
// If this operation is an AffineForOp, then fold the bounds.
if (auto forOp = op->dyn_cast<AffineForOp>()) {
constantFoldBounds(forOp);
@@ -111,7 +108,7 @@ PassResult ConstantFold::runOnFunction(Function *f) {
existingConstants.clear();
opInstsToErase.clear();
- walk(f);
+ f->walk([&](Instruction *inst) { foldInstruction(inst); });
// At this point, these operations are dead, remove them.
// TODO: This is assuming that all constant foldable operations have no
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 7a002168528..77e5a6aa04f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -28,7 +28,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@@ -111,22 +110,23 @@ namespace {
// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
-class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
-public:
+struct LoopNestStateCollector {
SmallVector<OpPointer<AffineForOp>, 4> forOps;
SmallVector<Instruction *, 4> loadOpInsts;
SmallVector<Instruction *, 4> storeOpInsts;
bool hasNonForRegion = false;
- void visitInstruction(Instruction *opInst) {
- if (opInst->isa<AffineForOp>())
- forOps.push_back(opInst->cast<AffineForOp>());
- else if (opInst->getNumBlockLists() != 0)
- hasNonForRegion = true;
- else if (opInst->isa<LoadOp>())
- loadOpInsts.push_back(opInst);
- else if (opInst->isa<StoreOp>())
- storeOpInsts.push_back(opInst);
+ void collect(Instruction *instToWalk) {
+ instToWalk->walk([&](Instruction *opInst) {
+ if (opInst->isa<AffineForOp>())
+ forOps.push_back(opInst->cast<AffineForOp>());
+ else if (opInst->getNumBlockLists() != 0)
+ hasNonForRegion = true;
+ else if (opInst->isa<LoadOp>())
+ loadOpInsts.push_back(opInst);
+ else if (opInst->isa<StoreOp>())
+ storeOpInsts.push_back(opInst);
+ });
}
};
@@ -510,7 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
- collector.walk(&inst);
+ collector.collect(&inst);
// Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion)
return false;
@@ -606,41 +606,39 @@ struct LoopNestStats {
// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
-class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
-public:
+struct LoopNestStatsCollector {
LoopNestStats *stats;
bool hasLoopWithNonConstTripCount = false;
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
- void visitInstruction(Instruction *opInst) {
- auto forOp = opInst->dyn_cast<AffineForOp>();
- if (!forOp)
- return;
-
- auto *forInst = forOp->getInstruction();
- auto *parentInst = forOp->getInstruction()->getParentInst();
- if (parentInst != nullptr) {
- assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
- // Add mapping to 'forOp' from its parent AffineForOp.
- stats->loopMap[parentInst].push_back(forOp);
- }
+ void collect(Instruction *inst) {
+ inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
+ auto *forInst = forOp->getInstruction();
+ auto *parentInst = forOp->getInstruction()->getParentInst();
+ if (parentInst != nullptr) {
+ assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
+ // Add mapping to 'forOp' from its parent AffineForOp.
+ stats->loopMap[parentInst].push_back(forOp);
+ }
- // Record the number of op instructions in the body of 'forOp'.
- unsigned count = 0;
- stats->opCountMap[forInst] = 0;
- for (auto &inst : *forOp->getBody()) {
- if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
- ++count;
- }
- stats->opCountMap[forInst] = count;
- // Record trip count for 'forOp'. Set flag if trip count is not constant.
- Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
- if (!maybeConstTripCount.hasValue()) {
- hasLoopWithNonConstTripCount = true;
- return;
- }
- stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
+ // Record the number of op instructions in the body of 'forOp'.
+ unsigned count = 0;
+ stats->opCountMap[forInst] = 0;
+ for (auto &inst : *forOp->getBody()) {
+ if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
+ ++count;
+ }
+ stats->opCountMap[forInst] = count;
+ // Record trip count for 'forOp'. Set flag if trip count is not
+ // constant.
+ Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+ if (!maybeConstTripCount.hasValue()) {
+ hasLoopWithNonConstTripCount = true;
+ return;
+ }
+ stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
+ });
}
};
@@ -1078,7 +1076,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
- srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
+ srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (srcStatsCollector.hasLoopWithNonConstTripCount)
return false;
@@ -1089,7 +1087,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
LoopNestStats dstLoopNestStats;
LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
- dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
+ dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (dstStatsCollector.hasLoopWithNonConstTripCount)
return false;
@@ -1474,7 +1472,7 @@ public:
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
- sliceCollector.walk(sliceLoopNest->getInstruction());
+ sliceCollector.collect(sliceLoopNest->getInstruction());
// Promote single iteration slice loops to single IV value.
for (auto forOp : sliceCollector.forOps) {
promoteIfSingleIteration(forOp);
@@ -1498,7 +1496,7 @@ public:
// Collect dst loop stats after memref privatizaton transformation.
LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.walk(dstAffineForOp->getInstruction());
+ dstLoopCollector.collect(dstAffineForOp->getInstruction());
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index b1e15ccb07b..3a7cfb85e08 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -27,7 +27,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -95,15 +94,16 @@ char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnFunction(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
- public:
+ struct InnermostLoopGatherer {
// Store innermost loops as we walk.
std::vector<OpPointer<AffineForOp>> loops;
- // This method specialized to encode custom return logic.
- using InstListType = llvm::iplist<Instruction>;
- bool walkPostOrder(InstListType::iterator Start,
- InstListType::iterator End) {
+ void walkPostOrder(Function *f) {
+ for (auto &b : *f)
+ walkPostOrder(b.begin(), b.end());
+ }
+
+ bool walkPostOrder(Block::iterator Start, Block::iterator End) {
bool hasInnerLoops = false;
// We need to walk all elements since all innermost loops need to be
// gathered as opposed to determining whether this list has any inner
@@ -112,7 +112,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
hasInnerLoops |= walkPostOrder(&(*Start++));
return hasInnerLoops;
}
-
bool walkPostOrder(Instruction *opInst) {
bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
@@ -125,39 +124,21 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
}
return hasInnerLoops;
}
-
- // FIXME: can't use base class method for this because that in turn would
- // need to use the derived class method above. CRTP doesn't allow it, and
- // the compiler error resulting from it is also misleading.
- using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
- // Gathers all loops with trip count <= minTripCount.
- class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
- public:
+ if (clUnrollFull.getNumOccurrences() > 0 &&
+ clUnrollFullThreshold.getNumOccurrences() > 0) {
// Store short loops as we walk.
std::vector<OpPointer<AffineForOp>> loops;
- const unsigned minTripCount;
- ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- void visitInstruction(Instruction *opInst) {
- auto forOp = opInst->dyn_cast<AffineForOp>();
- if (!forOp)
- return;
+ // Gathers all loops with trip count <= minTripCount. Do a post order walk
+ // so that loops are gathered from innermost to outermost (or else unrolling
+ // an outer one may delete gathered inner ones).
+ f->walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
Optional<uint64_t> tripCount = getConstantTripCount(forOp);
- if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
+ if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
loops.push_back(forOp);
- }
- };
-
- if (clUnrollFull.getNumOccurrences() > 0 &&
- clUnrollFullThreshold.getNumOccurrences() > 0) {
- ShortLoopGatherer slg(clUnrollFullThreshold);
- // Do a post order walk so that loops are gathered from innermost to
- // outermost (or else unrolling an outer one may delete gathered inner
- // ones).
- slg.walkPostOrder(f);
- auto &loops = slg.loops;
+ });
for (auto forOp : loops)
loopUnrollFull(forOp);
return success();
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 74c54fde047..b2aed7d9d7f 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -50,7 +50,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -136,24 +135,25 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (a instruction could have a descendant for inst though
// in its tree).
- class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
- public:
- using InstListType = llvm::iplist<Instruction>;
- using InstWalker<JamBlockGatherer>::walk;
-
+ struct JamBlockGatherer {
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
// This is a linear time walk.
- void walk(InstListType::iterator Start, InstListType::iterator End) {
- for (auto it = Start; it != End;) {
+ void walk(Instruction *inst) {
+ for (auto &blockList : inst->getBlockLists())
+ for (auto &block : blockList)
+ walk(block);
+ }
+ void walk(Block &block) {
+ for (auto it = block.begin(), e = block.end(); it != e;) {
auto subBlockStart = it;
- while (it != End && !it->isa<AffineForOp>())
+ while (it != e && !it->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
- while (it != End && it->isa<AffineForOp>())
+ while (it != e && it->isa<AffineForOp>())
walk(&*it++);
}
}
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 2d06a327315..9c9db30d163 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -25,7 +25,6 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@@ -70,12 +69,12 @@ namespace {
// currently only eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
-struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
+struct MemRefDataFlowOpt : public FunctionPass {
explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {}
PassResult runOnFunction(Function *f) override;
- void visitInstruction(Instruction *opInst);
+ void forwardStoreToLoad(OpPointer<LoadOp> loadOp);
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value *, 4> memrefsToErase;
@@ -100,14 +99,9 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
// This is a straightforward implementation not optimized for speed. Optimize
// this in the future if needed.
-void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
+void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer<LoadOp> loadOp) {
Instruction *lastWriteStoreOp = nullptr;
-
- auto loadOp = opInst->dyn_cast<LoadOp>();
- if (!loadOp)
- return;
-
- Instruction *loadOpInst = opInst;
+ Instruction *loadOpInst = loadOp->getInstruction();
// First pass over the use list to get minimum number of surrounding
// loops common between the load op and the store op, with min taken across
@@ -235,7 +229,8 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
memrefsToErase.clear();
// Walk all load's and perform load/store forwarding.
- walk(f);
+ f->walk<LoadOp>(
+ [&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); });
// Erase all load op's whose results were replaced with store fwd'ed ones.
for (auto *loadOp : loadOpsToErase) {
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index ba3be5e95f4..4ca48a53485 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -142,10 +142,8 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
- f->walkPostOrder([&](Instruction *opInst) {
- if (auto forOp = opInst->dyn_cast<AffineForOp>())
- forOps.push_back(forOp);
- });
+ f->walkPostOrder<AffineForOp>(
+ [&](OpPointer<AffineForOp> forOp) { forOps.push_back(forOp); });
bool ret = false;
for (auto forOp : forOps) {
ret = ret | runOnAffineForOp(forOp);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 5bf17989bef..95875adca6e 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -28,7 +28,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
@@ -135,10 +134,8 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- f->walkPostOrder([](Instruction *inst) {
- if (auto forOp = inst->dyn_cast<AffineForOp>())
- promoteIfSingleIteration(forOp);
- });
+ f->walkPostOrder<AffineForOp>(
+ [](OpPointer<AffineForOp> forOp) { promoteIfSingleIteration(forOp); });
}
/// Generates a 'for' inst with the specified lower and upper bounds while
OpenPOWER on IntegriCloud