diff options
Diffstat (limited to 'mlir/lib/Analysis')
| -rw-r--r-- | mlir/lib/Analysis/MemRefBoundCheck.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/Analysis/MemRefDependenceCheck.cpp | 16 | ||||
| -rw-r--r-- | mlir/lib/Analysis/OpStats.cpp | 17 |
3 files changed, 23 insertions, 34 deletions
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index ab22f261a3b..3376cd7d512 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -26,7 +26,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,13 +37,11 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> { +struct MemRefBoundCheck : public FunctionPass { explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst); - static char passID; }; @@ -56,17 +53,16 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitInstruction(Instruction *opInst) { - if (auto loadOp = opInst->dyn_cast<LoadOp>()) { - boundCheckLoadOrStoreOp(loadOp); - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { - boundCheckLoadOrStoreOp(storeOp); - } - // TODO(bondhugula): do this for DMA ops as well. -} - PassResult MemRefBoundCheck::runOnFunction(Function *f) { - return walk(f), success(); + f->walk([](Instruction *opInst) { + if (auto loadOp = opInst->dyn_cast<LoadOp>()) { + boundCheckLoadOrStoreOp(loadOp); + } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + boundCheckLoadOrStoreOp(storeOp); + } + // TODO(bondhugula): do this for DMA ops as well. + }); + return success(); } static PassRegistration<MemRefBoundCheck> diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 6ea47a20f60..9ec1c95f213 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -25,7 +25,6 @@ #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/Support/Debug.h" @@ -38,19 +37,13 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. -struct MemRefDependenceCheck : public FunctionPass, - InstWalker<MemRefDependenceCheck> { +struct MemRefDependenceCheck : public FunctionPass { SmallVector<Instruction *, 4> loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(Instruction *opInst) { - if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) { - loadsAndStores.push_back(opInst); - } - } static char passID; }; @@ -120,8 +113,13 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. PassResult MemRefDependenceCheck::runOnFunction(Function *f) { + // Collect the loads and stores within the function. loadsAndStores.clear(); - walk(f); + f->walk([&](Instruction *inst) { + if (inst->isa<LoadOp>() || inst->isa<StoreOp>()) + loadsAndStores.push_back(inst); + }); + checkDependences(loadsAndStores); return success(); } diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 742c0baa96b..f05f8737b16 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -15,7 +15,6 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/InstVisitor.h" #include "mlir/IR/Instruction.h" #include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" @@ -27,16 +26,13 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> { +struct PrintOpStatsPass : public ModulePass { explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : ModulePass(&PrintOpStatsPass::passID), os(os) {} // Prints the resultant operation statistics post iterating over the module. PassResult runOnModule(Module *m) override; - // Updates the operation statistics for the given instruction. - void visitInstruction(Instruction *inst); - // Print summary of op stats. void printSummary(); @@ -44,7 +40,6 @@ struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> { private: llvm::StringMap<int64_t> opCount; - llvm::raw_ostream &os; }; } // namespace @@ -52,16 +47,16 @@ private: char PrintOpStatsPass::passID = 0; PassResult PrintOpStatsPass::runOnModule(Module *m) { + opCount.clear(); + + // Compute the operation statistics for each function in the module. for (auto &fn : *m) - walk(&fn); + fn.walk( + [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); printSummary(); return success(); } -void PrintOpStatsPass::visitInstruction(Instruction *inst) { - ++opCount[inst->getName().getStringRef()]; -} - void PrintOpStatsPass::printSummary() { os << "Operations encountered:\n"; os << "-----------------------\n"; |

