summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/g3doc/WritingAPass.md3
-rw-r--r--mlir/include/mlir/Pass/PassInstrumentation.h73
-rw-r--r--mlir/lib/Pass/IRPrinting.cpp43
-rw-r--r--mlir/lib/Pass/Pass.cpp20
-rw-r--r--mlir/lib/Pass/PassTiming.cpp20
5 files changed, 61 insertions, 98 deletions
diff --git a/mlir/g3doc/WritingAPass.md b/mlir/g3doc/WritingAPass.md
index dc06ace2519..47e57df6b77 100644
--- a/mlir/g3doc/WritingAPass.md
+++ b/mlir/g3doc/WritingAPass.md
@@ -389,8 +389,7 @@ struct DominanceCounterInstrumentation : public PassInstrumentation {
unsigned &count;
DominanceCounterInstrumentation(unsigned &count) : count(count) {}
- void runAfterAnalysis(llvm::StringRef, AnalysisID *id,
- const llvm::Any &) override {
+ void runAfterAnalysis(llvm::StringRef, AnalysisID *id, Operation *) override {
if (id == AnalysisID::getID<DominanceInfo>())
++count;
}
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 40358329f45..46df6fdd877 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -20,11 +20,11 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/Any.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
using AnalysisID = ClassID;
+class Operation;
class Pass;
namespace detail {
@@ -39,32 +39,32 @@ public:
virtual ~PassInstrumentation() = 0;
/// A callback to run before a pass is executed. This function takes a pointer
- /// to the pass to be executed, as well as an llvm::Any holding a pointer to
- /// the IR unit being transformed on.
- virtual void runBeforePass(Pass *pass, const llvm::Any &ir) {}
+ /// to the pass to be executed, as well as the current operation being
+ /// operated on.
+ virtual void runBeforePass(Pass *pass, Operation *op) {}
/// A callback to run after a pass is successfully executed. This function
- /// takes a pointer to the pass to be executed, as well as an llvm::Any
- /// holding a pointer to the IR unit being transformed on.
- virtual void runAfterPass(Pass *pass, const llvm::Any &ir) {}
+ /// takes a pointer to the pass to be executed, as well as the current
+ /// operation being operated on.
+ virtual void runAfterPass(Pass *pass, Operation *op) {}
/// A callback to run when a pass execution fails. This function takes a
- /// pointer to the pass that was being executed, as well as an llvm::Any
- /// holding a pointer to the IR unit that was being transformed. Note
- /// that the ir unit may be in an invalid state.
- virtual void runAfterPassFailed(Pass *pass, const llvm::Any &ir) {}
+ /// pointer to the pass that was being executed, as well as the current
+ /// operation being operated on. Note that the operation may be in an invalid
+ /// state.
+ virtual void runAfterPassFailed(Pass *pass, Operation *op) {}
/// A callback to run before an analysis is computed. This function takes the
- /// name of the analysis to be computed, its AnalysisID, as well as an
- /// llvm::Any holding a pointer to the IR unit being analyzed on.
+ /// name of the analysis to be computed, its AnalysisID, as well as the
+ /// current operation being analyzed.
virtual void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &ir) {}
+ Operation *op) {}
/// A callback to run before an analysis is computed. This function takes the
- /// name of the analysis that was computed, its AnalysisID, as well as an
- /// llvm::Any holding a pointer to the IR unit that was analyzed.
+ /// name of the analysis that was computed, its AnalysisID, as well as the
+ /// current operation being analyzed.
virtual void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &ir) {}
+ Operation *op) {}
};
/// This class holds a collection of PassInstrumentation objects, and invokes
@@ -77,54 +77,25 @@ public:
~PassInstrumentor();
/// See PassInstrumentation::runBeforePass for details.
- template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) {
- runBeforePass(pass, llvm::Any(ir));
- }
+ void runBeforePass(Pass *pass, Operation *op);
/// See PassInstrumentation::runAfterPass for details.
- template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) {
- runAfterPass(pass, llvm::Any(ir));
- }
+ void runAfterPass(Pass *pass, Operation *op);
/// See PassInstrumentation::runAfterPassFailed for details.
- template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) {
- runAfterPassFailed(pass, llvm::Any(ir));
- }
+ void runAfterPassFailed(Pass *pass, Operation *op);
/// See PassInstrumentation::runBeforeAnalysis for details.
- template <typename IRUnitT>
- void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
- runBeforeAnalysis(name, id, llvm::Any(ir));
- }
+ void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op);
/// See PassInstrumentation::runAfterAnalysis for details.
- template <typename IRUnitT>
- void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
- runAfterAnalysis(name, id, llvm::Any(ir));
- }
+ void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op);
/// Add the given instrumentation to the collection. This takes ownership over
/// the given pointer.
void addInstrumentation(PassInstrumentation *pi);
private:
- /// See PassInstrumentation::runBeforePass for details.
- void runBeforePass(Pass *pass, const llvm::Any &ir);
-
- /// See PassInstrumentation::runAfterPass for details.
- void runAfterPass(Pass *pass, const llvm::Any &ir);
-
- /// See PassInstrumentation::runAfterPassFailed for details.
- void runAfterPassFailed(Pass *pass, const llvm::Any &ir);
-
- /// See PassInstrumentation::runBeforeAnalysis for details.
- void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &ir);
-
- /// See PassInstrumentation::runAfterAnalysis for details.
- void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &ir);
-
std::unique_ptr<detail::PassInstrumentorImpl> impl;
};
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 2de4b05a36c..bc661979e59 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -27,8 +27,8 @@ using namespace mlir::detail;
namespace {
class IRPrinterInstrumentation : public PassInstrumentation {
public:
- /// A filter function to decide if the given ir should be printed. Returns
- /// true if the ir should be printed, false otherwise.
+ /// A filter function to decide if the given pass should be printed. Returns
+ /// true if the pass should be printed, false otherwise.
using ShouldPrintFn = std::function<bool(Pass *)>;
IRPrinterInstrumentation(ShouldPrintFn &&shouldPrintBeforePass,
@@ -43,9 +43,9 @@ public:
private:
/// Instrumentation hooks.
- void runBeforePass(Pass *pass, const llvm::Any &ir) override;
- void runAfterPass(Pass *pass, const llvm::Any &ir) override;
- void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override;
+ void runBeforePass(Pass *pass, Operation *op) override;
+ void runAfterPass(Pass *pass, Operation *op) override;
+ void runAfterPassFailed(Pass *pass, Operation *op) override;
/// Filter functions for before and after pass execution.
ShouldPrintFn shouldPrintBeforePass, shouldPrintAfterPass;
@@ -63,12 +63,10 @@ static bool isHiddenPass(Pass *pass) {
return isAdaptorPass(pass) || isVerifierPass(pass);
}
-static void printIR(const llvm::Any &ir, bool printModuleScope,
- raw_ostream &out) {
+static void printIR(Operation *op, bool printModuleScope, raw_ostream &out) {
// Check for printing at module scope.
- if (printModuleScope && llvm::any_isa<FuncOp>(ir)) {
- FuncOp function = llvm::any_cast<FuncOp>(ir);
-
+ auto function = dyn_cast<FuncOp>(op);
+ if (printModuleScope && function) {
// Print the function name and a newline before the Module.
out << " (function: " << function.getName() << ")\n";
function.getParentOfType<ModuleOp>().print(out);
@@ -79,45 +77,44 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
out << "\n";
// Print the given function.
- if (llvm::any_isa<FuncOp>(ir)) {
- llvm::any_cast<FuncOp>(ir).print(out);
+ if (function) {
+ function.print(out);
return;
}
// Print the given module.
- assert(llvm::any_isa<ModuleOp>(ir) && "unexpected IR unit");
- llvm::any_cast<ModuleOp>(ir).print(out);
+ assert(isa<ModuleOp>(op) && "unexpected IR unit");
+ cast<ModuleOp>(op).print(out);
}
/// Instrumentation hooks.
-void IRPrinterInstrumentation::runBeforePass(Pass *pass, const llvm::Any &ir) {
- // Skip adaptor passes and passes that the user filtered out.
+void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
+ // Skip hidden passes and passes that the user filtered out.
if (!shouldPrintBeforePass || isHiddenPass(pass) ||
!shouldPrintBeforePass(pass))
return;
out << formatv("*** IR Dump Before {0} ***", pass->getName());
- printIR(ir, printModuleScope, out);
+ printIR(op, printModuleScope, out);
out << "\n\n";
}
-void IRPrinterInstrumentation::runAfterPass(Pass *pass, const llvm::Any &ir) {
- // Skip adaptor passes and passes that the user filtered out.
+void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
+ // Skip hidden passes and passes that the user filtered out.
if (!shouldPrintAfterPass || isHiddenPass(pass) ||
!shouldPrintAfterPass(pass))
return;
out << formatv("*** IR Dump After {0} ***", pass->getName());
- printIR(ir, printModuleScope, out);
+ printIR(op, printModuleScope, out);
out << "\n\n";
}
-void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass,
- const llvm::Any &ir) {
+void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
// Skip adaptor passes and passes that the user filtered out.
if (!shouldPrintAfterPass || isAdaptorPass(pass) ||
!shouldPrintAfterPass(pass))
return;
out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
- printIR(ir, printModuleScope, out);
+ printIR(op, printModuleScope, out);
out << "\n\n";
}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 35d96634cf1..ba3b4742cc7 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -393,40 +393,40 @@ PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
PassInstrumentor::~PassInstrumentor() {}
/// See PassInstrumentation::runBeforePass for details.
-void PassInstrumentor::runBeforePass(Pass *pass, const llvm::Any &ir) {
+void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : impl->instrumentations)
- instr->runBeforePass(pass, ir);
+ instr->runBeforePass(pass, op);
}
/// See PassInstrumentation::runAfterPass for details.
-void PassInstrumentor::runAfterPass(Pass *pass, const llvm::Any &ir) {
+void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : llvm::reverse(impl->instrumentations))
- instr->runAfterPass(pass, ir);
+ instr->runAfterPass(pass, op);
}
/// See PassInstrumentation::runAfterPassFailed for details.
-void PassInstrumentor::runAfterPassFailed(Pass *pass, const llvm::Any &ir) {
+void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : llvm::reverse(impl->instrumentations))
- instr->runAfterPassFailed(pass, ir);
+ instr->runAfterPassFailed(pass, op);
}
/// See PassInstrumentation::runBeforeAnalysis for details.
void PassInstrumentor::runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &ir) {
+ Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : impl->instrumentations)
- instr->runBeforeAnalysis(name, id, ir);
+ instr->runBeforeAnalysis(name, id, op);
}
/// See PassInstrumentation::runAfterAnalysis for details.
void PassInstrumentor::runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &ir) {
+ Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : llvm::reverse(impl->instrumentations))
- instr->runAfterAnalysis(name, id, ir);
+ instr->runAfterAnalysis(name, id, op);
}
/// Add the given instrumentation to the collection. This takes ownership over
diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index b4f375628c7..91b838cfc1e 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -154,19 +154,16 @@ struct PassTiming : public PassInstrumentation {
~PassTiming() { print(); }
/// Setup the instrumentation hooks.
- void runBeforePass(Pass *pass, const llvm::Any &) override {
- startPassTimer(pass);
- }
- void runAfterPass(Pass *pass, const llvm::Any &) override;
- void runAfterPassFailed(Pass *pass, const llvm::Any &ir) override {
- runAfterPass(pass, ir);
+ void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); }
+ void runAfterPass(Pass *pass, Operation *) override;
+ void runAfterPassFailed(Pass *pass, Operation *op) override {
+ runAfterPass(pass, op);
}
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
- const llvm::Any &) override {
+ Operation *) override {
startAnalysisTimer(name, id);
}
- void runAfterAnalysis(llvm::StringRef, AnalysisID *,
- const llvm::Any &) override;
+ void runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) override;
/// Print and clear the timing results.
void print();
@@ -243,7 +240,7 @@ void PassTiming::startAnalysisTimer(llvm::StringRef name, AnalysisID *id) {
}
/// Stop a pass timer.
-void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) {
+void PassTiming::runAfterPass(Pass *pass, Operation *) {
auto tid = llvm::get_threadid();
auto &activeTimers = activeThreadTimers[tid];
assert(!activeTimers.empty() && "expected active timer");
@@ -277,8 +274,7 @@ void PassTiming::runAfterPass(Pass *pass, const llvm::Any &) {
}
/// Stop a timer.
-void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *,
- const llvm::Any &) {
+void PassTiming::runAfterAnalysis(llvm::StringRef, AnalysisID *, Operation *) {
auto &activeTimers = activeThreadTimers[llvm::get_threadid()];
assert(!activeTimers.empty() && "expected active timer");
Timer *timer = activeTimers.pop_back_val();
OpenPOWER on IntegriCloud