diff options
Diffstat (limited to 'mlir/lib/Pass/IRPrinting.cpp')
-rw-r--r-- | mlir/lib/Pass/IRPrinting.cpp | 43 |
1 files changed, 20 insertions, 23 deletions
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 2de4b05a36c..bc661979e59 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -27,8 +27,8 @@ using namespace mlir::detail; namespace { class IRPrinterInstrumentation : public PassInstrumentation { public: - /// A filter function to decide if the given ir should be printed. Returns - /// true if the ir should be printed, false otherwise. + /// A filter function to decide if the given pass should be printed. Returns + /// true if the pass should be printed, false otherwise. using ShouldPrintFn = std::function<bool(Pass *)>; IRPrinterInstrumentation(ShouldPrintFn &&shouldPrintBeforePass, @@ -43,9 +43,9 @@ public: private: /// Instrumentation hooks. - void runBeforePass(Pass *pass, const llvm::Any &ir) override; - void runAfterPass(Pass *pass, const llvm::Any &ir) override; - void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override; + void runBeforePass(Pass *pass, Operation *op) override; + void runAfterPass(Pass *pass, Operation *op) override; + void runAfterPassFailed(Pass *pass, Operation *op) override; /// Filter functions for before and after pass execution. ShouldPrintFn shouldPrintBeforePass, shouldPrintAfterPass; @@ -63,12 +63,10 @@ static bool isHiddenPass(Pass *pass) { return isAdaptorPass(pass) || isVerifierPass(pass); } -static void printIR(const llvm::Any &ir, bool printModuleScope, - raw_ostream &out) { +static void printIR(Operation *op, bool printModuleScope, raw_ostream &out) { // Check for printing at module scope. - if (printModuleScope && llvm::any_isa<FuncOp>(ir)) { - FuncOp function = llvm::any_cast<FuncOp>(ir); - + auto function = dyn_cast<FuncOp>(op); + if (printModuleScope && function) { // Print the function name and a newline before the Module. out << " (function: " << function.getName() << ")\n"; function.getParentOfType<ModuleOp>().print(out); @@ -79,45 +77,44 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, out << "\n"; // Print the given function. - if (llvm::any_isa<FuncOp>(ir)) { - llvm::any_cast<FuncOp>(ir).print(out); + if (function) { + function.print(out); return; } // Print the given module. - assert(llvm::any_isa<ModuleOp>(ir) && "unexpected IR unit"); - llvm::any_cast<ModuleOp>(ir).print(out); + assert(isa<ModuleOp>(op) && "unexpected IR unit"); + cast<ModuleOp>(op).print(out); } /// Instrumentation hooks. -void IRPrinterInstrumentation::runBeforePass(Pass *pass, const llvm::Any &ir) { - // Skip adaptor passes and passes that the user filtered out. +void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { + // Skip hidden passes and passes that the user filtered out. if (!shouldPrintBeforePass || isHiddenPass(pass) || !shouldPrintBeforePass(pass)) return; out << formatv("*** IR Dump Before {0} ***", pass->getName()); - printIR(ir, printModuleScope, out); + printIR(op, printModuleScope, out); out << "\n\n"; } -void IRPrinterInstrumentation::runAfterPass(Pass *pass, const llvm::Any &ir) { - // Skip adaptor passes and passes that the user filtered out. +void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { + // Skip hidden passes and passes that the user filtered out. if (!shouldPrintAfterPass || isHiddenPass(pass) || !shouldPrintAfterPass(pass)) return; out << formatv("*** IR Dump After {0} ***", pass->getName()); - printIR(ir, printModuleScope, out); + printIR(op, printModuleScope, out); out << "\n\n"; } -void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, - const llvm::Any &ir) { +void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { // Skip adaptor passes and passes that the user filtered out. if (!shouldPrintAfterPass || isAdaptorPass(pass) || !shouldPrintAfterPass(pass)) return; out << formatv("*** IR Dump After {0} Failed ***", pass->getName()); - printIR(ir, printModuleScope, out); + printIR(op, printModuleScope, out); out << "\n\n"; } |