diff options
-rw-r--r-- | mlir/g3doc/WritingAPass.md | 3 | ||||
-rw-r--r-- | mlir/include/mlir/Pass/PassInstrumentation.h | 73 | ||||
-rw-r--r-- | mlir/lib/Pass/IRPrinting.cpp | 43 | ||||
-rw-r--r-- | mlir/lib/Pass/Pass.cpp | 20 | ||||
-rw-r--r-- | mlir/lib/Pass/PassTiming.cpp | 20 |
5 files changed, 61 insertions, 98 deletions
diff --git a/mlir/g3doc/WritingAPass.md b/mlir/g3doc/WritingAPass.md index dc06ace2519..47e57df6b77 100644 --- a/mlir/g3doc/WritingAPass.md +++ b/mlir/g3doc/WritingAPass.md @@ -389,8 +389,7 @@ struct DominanceCounterInstrumentation : public PassInstrumentation { unsigned &count; DominanceCounterInstrumentation(unsigned &count) : count(count) {} - void runAfterAnalysis(llvm::StringRef, AnalysisID *id, - const llvm::Any &) override { + void runAfterAnalysis(llvm::StringRef, AnalysisID *id, Operation *) override { if (id == AnalysisID::getID<DominanceInfo>()) ++count; } diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h index 40358329f45..46df6fdd877 100644 --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -20,11 +20,11 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" -#include "llvm/ADT/Any.h" #include "llvm/ADT/StringRef.h" namespace mlir { using AnalysisID = ClassID; +class Operation; class Pass; namespace detail { @@ -39,32 +39,32 @@ public: virtual ~PassInstrumentation() = 0; /// A callback to run before a pass is executed. This function takes a pointer - /// to the pass to be executed, as well as an llvm::Any holding a pointer to - /// the IR unit being transformed on. - virtual void runBeforePass(Pass *pass, const llvm::Any &ir) {} + /// to the pass to be executed, as well as the current operation being + /// operated on. + virtual void runBeforePass(Pass *pass, Operation *op) {} /// A callback to run after a pass is successfully executed. This function - /// takes a pointer to the pass to be executed, as well as an llvm::Any - /// holding a pointer to the IR unit being transformed on. - virtual void runAfterPass(Pass *pass, const llvm::Any &ir) {} + /// takes a pointer to the pass to be executed, as well as the current + /// operation being operated on. + virtual void runAfterPass(Pass *pass, Operation *op) {} /// A callback to run when a pass execution fails. This function takes a - /// pointer to the pass that was being executed, as well as an llvm::Any - /// holding a pointer to the IR unit that was being transformed. Note - /// that the ir unit may be in an invalid state. - virtual void runAfterPassFailed(Pass *pass, const llvm::Any &ir) {} + /// pointer to the pass that was being executed, as well as the current + /// operation being operated on. Note that the operation may be in an invalid + /// state. + virtual void runAfterPassFailed(Pass *pass, Operation *op) {} /// A callback to run before an analysis is computed. This function takes the - /// name of the analysis to be computed, its AnalysisID, as well as an - /// llvm::Any holding a pointer to the IR unit being analyzed on. + /// name of the analysis to be computed, its AnalysisID, as well as the + /// current operation being analyzed. virtual void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &ir) {} + Operation *op) {} /// A callback to run before an analysis is computed. This function takes the - /// name of the analysis that was computed, its AnalysisID, as well as an - /// llvm::Any holding a pointer to the IR unit that was analyzed. + /// name of the analysis that was computed, its AnalysisID, as well as the + /// current operation being analyzed. virtual void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &ir) {} + Operation *op) {} }; /// This class holds a collection of PassInstrumentation objects, and invokes @@ -77,54 +77,25 @@ public: ~PassInstrumentor(); /// See PassInstrumentation::runBeforePass for details. - template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) { - runBeforePass(pass, llvm::Any(ir)); - } + void runBeforePass(Pass *pass, Operation *op); /// See PassInstrumentation::runAfterPass for details. - template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) { - runAfterPass(pass, llvm::Any(ir)); - } + void runAfterPass(Pass *pass, Operation *op); /// See PassInstrumentation::runAfterPassFailed for details. - template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) { - runAfterPassFailed(pass, llvm::Any(ir)); - } + void runAfterPassFailed(Pass *pass, Operation *op); /// See PassInstrumentation::runBeforeAnalysis for details. - template <typename IRUnitT> - void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) { - runBeforeAnalysis(name, id, llvm::Any(ir)); - } + void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op); /// See PassInstrumentation::runAfterAnalysis for details. - template <typename IRUnitT> - void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) { - runAfterAnalysis(name, id, llvm::Any(ir)); - } + void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op); /// Add the given instrumentation to the collection. This takes ownership over /// the given pointer. void addInstrumentation(PassInstrumentation *pi); private: - /// See PassInstrumentation::runBeforePass for details. - void runBeforePass(Pass *pass, const llvm::Any &ir); - - /// See PassInstrumentation::runAfterPass for details. - void runAfterPass(Pass *pass, const llvm::Any &ir); - - /// See PassInstrumentation::runAfterPassFailed for details. - void runAfterPassFailed(Pass *pass, const llvm::Any &ir); - - /// See PassInstrumentation::runBeforeAnalysis for details. - void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &ir); - - /// See PassInstrumentation::runAfterAnalysis for details. - void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &ir); - std::unique_ptr<detail::PassInstrumentorImpl> impl; }; diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 2de4b05a36c..bc661979e59 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -27,8 +27,8 @@ using namespace mlir::detail; namespace { class IRPrinterInstrumentation : public PassInstrumentation { public: - /// A filter function to decide if the given ir should be printed. Returns - /// true if the ir should be printed, false otherwise. + /// A filter function to decide if the given pass should be printed. Returns + /// true if the pass should be printed, false otherwise. using ShouldPrintFn = std::function<bool(Pass *)>; IRPrinterInstrumentation(ShouldPrintFn &&shouldPrintBeforePass, @@ -43,9 +43,9 @@ public: private: /// Instrumentation hooks. - void runBeforePass(Pass *pass, const llvm::Any &ir) override; - void runAfterPass(Pass *pass, const llvm::Any &ir) override; - void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override; + void runBeforePass(Pass *pass, Operation *op) override; + void runAfterPass(Pass *pass, Operation *op) override; + void runAfterPassFailed(Pass *pass, Operation *op) override; /// Filter functions for before and after pass execution. ShouldPrintFn shouldPrintBeforePass, shouldPrintAfterPass; @@ -63,12 +63,10 @@ static bool isHiddenPass(Pass *pass) { return isAdaptorPass(pass) || isVerifierPass(pass); } -static void printIR(const llvm::Any &ir, bool printModuleScope, - raw_ostream &out) { +static void printIR(Operation *op, bool printModuleScope, raw_ostream &out) { // Check for printing at module scope. - if (printModuleScope && llvm::any_isa<FuncOp>(ir)) { - FuncOp function = llvm::any_cast<FuncOp>(ir); - + auto function = dyn_cast<FuncOp>(op); + if (printModuleScope && function) { // Print the function name and a newline before the Module. out << " (function: " << function.getName() << ")\n"; function.getParentOfType<ModuleOp>().print(out); @@ -79,45 +77,44 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, out << "\n"; // Print the given function. - if (llvm::any_isa<FuncOp>(ir)) { - llvm::any_cast<FuncOp>(ir).print(out); + if (function) { + function.print(out); return; } // Print the given module. - assert(llvm::any_isa<ModuleOp>(ir) && "unexpected IR unit"); - llvm::any_cast<ModuleOp>(ir).print(out); + assert(isa<ModuleOp>(op) && "unexpected IR unit"); + cast<ModuleOp>(op).print(out); } /// Instrumentation hooks. -void IRPrinterInstrumentation::runBeforePass(Pass *pass, const llvm::Any &ir) { - // Skip adaptor passes and passes that the user filtered out. +void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { + // Skip hidden passes and passes that the user filtered out. if (!shouldPrintBeforePass || isHiddenPass(pass) || !shouldPrintBeforePass(pass)) return; out << formatv("*** IR Dump Before {0} ***", pass->getName()); - printIR(ir, printModuleScope, out); + printIR(op, printModuleScope, out); out << "\n\n"; } -void IRPrinterInstrumentation::runAfterPass(Pass *pass, const llvm::Any &ir) { - // Skip adaptor passes and passes that the user filtered out. +void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { + // Skip hidden passes and passes that the user filtered out. if (!shouldPrintAfterPass || isHiddenPass(pass) || !shouldPrintAfterPass(pass)) return; out << formatv("*** IR Dump After {0} ***", pass->getName()); - printIR(ir, printModuleScope, out); + printIR(op, printModuleScope, out); out << "\n\n"; } -void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, - const llvm::Any &ir) { +void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { // Skip adaptor passes and passes that the user filtered out. if (!shouldPrintAfterPass || isAdaptorPass(pass) || !shouldPrintAfterPass(pass)) return; out << formatv("*** IR Dump After {0} Failed ***", pass->getName()); - printIR(ir, printModuleScope, out); + printIR(op, printModuleScope, out); out << "\n\n"; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 35d96634cf1..ba3b4742cc7 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -393,40 +393,40 @@ PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {} PassInstrumentor::~PassInstrumentor() {} /// See PassInstrumentation::runBeforePass for details. -void PassInstrumentor::runBeforePass(Pass *pass, const llvm::Any &ir) { +void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : impl->instrumentations) - instr->runBeforePass(pass, ir); + instr->runBeforePass(pass, op); } /// See PassInstrumentation::runAfterPass for details. -void PassInstrumentor::runAfterPass(Pass *pass, const llvm::Any &ir) { +void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : llvm::reverse(impl->instrumentations)) - instr->runAfterPass(pass, ir); + instr->runAfterPass(pass, op); } /// See PassInstrumentation::runAfterPassFailed for details. -void PassInstrumentor::runAfterPassFailed(Pass *pass, const llvm::Any &ir) { +void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : llvm::reverse(impl->instrumentations)) - instr->runAfterPassFailed(pass, ir); + instr->runAfterPassFailed(pass, op); } /// See PassInstrumentation::runBeforeAnalysis for details. void PassInstrumentor::runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &ir) { + Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : impl->instrumentations) - instr->runBeforeAnalysis(name, id, ir); + instr->runBeforeAnalysis(name, id, op); } /// See PassInstrumentation::runAfterAnalysis for details. void PassInstrumentor::runAfterAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &ir) { + Operation *op) { llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); for (auto &instr : llvm::reverse(impl->instrumentations)) - instr->runAfterAnalysis(name, id, ir); + instr->runAfterAnalysis(name, id, op); } /// Add the given instrumentation to the collection. This takes ownership over diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index b4f375628c7..91b838cfc1e 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -154,19 +154,16 @@ struct PassTiming : public PassInstrumentation { ~PassTiming() { print(); } /// Setup the instrumentation hooks. - void runBeforePass(Pass *pass, const llvm::Any &) override { - startPassTimer(pass); - } - void runAfterPass(Pass *pass, const llvm::Any &) override; - void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override { - runAfterPass(pass, ir); + void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); } + void runAfterPass(Pass *pass, Operation *) override; + void runAfterPassFailed(Pass *pass, Operation *op) override { + runAfterPass(pass, op); } void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, - const llvm::Any &) override { + Operation *) override { startAnalysisTimer(name, id); } - void runAfterAnalysis(llvm::StringRef, AnalysisID *, - const llvm::Any &) override; + void runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) override; /// Print and clear the timing results. void print(); @@ -243,7 +240,7 @@ void PassTiming::startAnalysisTimer(llvm::StringRef name, AnalysisID *id) { } /// Stop a pass timer. -void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) { +void PassTiming::runAfterPass(Pass *pass, Operation *) { auto tid = llvm::get_threadid(); auto &activeTimers = activeThreadTimers[tid]; assert(!activeTimers.empty() && "expected active timer"); @@ -277,8 +274,7 @@ void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) { } /// Stop a timer. -void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *, - const llvm::Any &) { +void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) { auto &activeTimers = activeThreadTimers[llvm::get_threadid()]; assert(!activeTimers.empty() && "expected active timer"); Timer *timer = activeTimers.pop_back_val(); |