summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2020-01-14 15:23:05 -0800
committerRiver Riddle <riverriddle@google.com>2020-01-14 15:23:31 -0800
commitfa9dd8336bbd1167926f93fe2018d0c47839d5d6 (patch)
treef151e58b0e13a0a551105863f832192c0126a1e9 /mlir/lib/IR
parent23058f9dd4d7e18239fd63b6da52549514b45fda (diff)
downloadbcm5719-llvm-fa9dd8336bbd1167926f93fe2018d0c47839d5d6.tar.gz
bcm5719-llvm-fa9dd8336bbd1167926f93fe2018d0c47839d5d6.zip
[mlir] Refactor ModuleState into AsmState and expose it to users.
Summary: This allows for users to cache printer state, which can be costly to recompute. Each of the IR print methods gain a new overload taking this new state class. Depends On D72293 Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D72294
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp108
1 files changed, 69 insertions, 39 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index afb94fea0cd..6f9d76e20f0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
@@ -37,6 +38,7 @@
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
using namespace mlir;
+using namespace mlir::detail;
void Identifier::print(raw_ostream &os) const { os << str(); }
@@ -756,13 +758,14 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
}
//===----------------------------------------------------------------------===//
-// ModuleState
+// AsmState
//===----------------------------------------------------------------------===//
-namespace {
-class ModuleState {
+namespace mlir {
+namespace detail {
+class AsmStateImpl {
public:
- explicit ModuleState(Operation *op)
+ explicit AsmStateImpl(Operation *op)
: interfaces(op->getContext()), nameState(op, interfaces) {}
/// Initialize the alias state to enable the printing of aliases.
@@ -792,7 +795,11 @@ private:
/// The state used for SSA value names.
SSANameState nameState;
};
-} // end anonymous namespace
+} // end namespace detail
+} // end namespace mlir
+
+AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {}
+AsmState::~AsmState() {}
//===----------------------------------------------------------------------===//
// ModulePrinter
@@ -802,7 +809,7 @@ namespace {
class ModulePrinter {
public:
ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
- ModuleState *state = nullptr)
+ AsmStateImpl *state = nullptr)
: os(os), printerFlags(flags), state(state) {}
explicit ModulePrinter(ModulePrinter &printer)
: os(printer.os), printerFlags(printer.printerFlags),
@@ -816,8 +823,6 @@ public:
mlir::interleaveComma(c, os, each_fn);
}
- void print(ModuleOp module);
-
/// Print the given attribute. If 'mayElideType' is true, some attributes are
/// printed without the type when the type matches the default used in the
/// parser (for example i64 is the default for integer attributes).
@@ -862,7 +867,7 @@ protected:
OpPrintingFlags printerFlags;
/// An optional printer state for the module.
- ModuleState *state;
+ AsmStateImpl *state;
};
} // end anonymous namespace
@@ -1815,10 +1820,12 @@ namespace {
/// This class contains the logic for printing operations, regions, and blocks.
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
public:
- explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) {
- assert(state && "expected valid state when printing operation");
- }
+ explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
+ AsmStateImpl &state)
+ : ModulePrinter(os, flags, &state) {}
+ /// Print the given top-level module.
+ void print(ModuleOp op);
/// Print the given operation with its indent and location.
void print(Operation *op);
/// Print the bare location, not including indentation/location/etc.
@@ -1903,6 +1910,15 @@ private:
};
} // end anonymous namespace
+void OperationPrinter::print(ModuleOp op) {
+ // Output the aliases at the top level.
+ state->getAliasState().printAttributeAliases(os);
+ state->getAliasState().printTypeAliases(os);
+
+ // Print the module.
+ print(op.getOperation());
+}
+
void OperationPrinter::print(Operation *op) {
os.indent(currentIndent);
printOperation(op);
@@ -2108,18 +2124,6 @@ void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
});
}
-void ModulePrinter::print(ModuleOp module) {
- assert(state && "expected valid state when printing an operation");
-
- // Output the aliases at the top level.
- state->getAliasState().printAttributeAliases(os);
- state->getAliasState().printTypeAliases(os);
-
- // Print the module.
- OperationPrinter(*this).print(module);
- os << '\n';
-}
-
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
@@ -2179,18 +2183,34 @@ void Value::print(raw_ostream &os) {
assert(isa<BlockArgument>());
os << "<block argument>\n";
}
+void Value::print(raw_ostream &os, AsmState &state) {
+ if (auto *op = getDefiningOp())
+ return op->print(os, state);
+
+ // TODO: Improve this.
+ assert(isa<BlockArgument>());
+ os << "<block argument>\n";
+}
void Value::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
+void Value::printAsOperand(raw_ostream &os, AsmState &state) {
+ // TODO(riverriddle) This doesn't necessarily capture all potential cases.
+ // Currently, region arguments can be shadowed when printing the main
+ // operation. If the IR hasn't been printed, this will produce the old SSA
+ // name and not the shadowed name.
+ state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
+ os);
+}
+
void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
// Handle top-level operations or local printing.
if (!getParent() || flags.shouldUseLocalScope()) {
- ModuleState state(this);
- ModulePrinter modulePrinter(os, flags, &state);
- OperationPrinter(modulePrinter).print(this);
+ AsmState state(this);
+ OperationPrinter(os, flags, state.getImpl()).print(this);
return;
}
@@ -2203,9 +2223,11 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
while (auto *nextOp = parentOp->getParentOp())
parentOp = nextOp;
- ModuleState state(parentOp);
- ModulePrinter modulePrinter(os, flags, &state);
- OperationPrinter(modulePrinter).print(this);
+ AsmState state(parentOp);
+ print(os, state, flags);
+}
+void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
+ OperationPrinter(os, flags, state.getImpl()).print(this);
}
void Operation::dump() {
@@ -2223,9 +2245,11 @@ void Block::print(raw_ostream &os) {
while (auto *nextOp = parentOp->getParentOp())
parentOp = nextOp;
- ModuleState state(parentOp);
- ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
- OperationPrinter(modulePrinter).print(this);
+ AsmState state(parentOp);
+ print(os, state);
+}
+void Block::print(raw_ostream &os, AsmState &state) {
+ OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
}
void Block::dump() { print(llvm::errs()); }
@@ -2241,18 +2265,24 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
while (auto *nextOp = parentOp->getParentOp())
parentOp = nextOp;
- ModuleState state(parentOp);
- ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
- OperationPrinter(modulePrinter).printBlockName(this);
+ AsmState state(parentOp);
+ printAsOperand(os, state);
+}
+void Block::printAsOperand(raw_ostream &os, AsmState &state) {
+ OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
+ printer.printBlockName(this);
}
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
- ModuleState state(*this);
+ AsmState state(*this);
// Don't populate aliases when printing at local scope.
if (!flags.shouldUseLocalScope())
- state.initializeAliases(*this);
- ModulePrinter(os, flags, &state).print(*this);
+ state.getImpl().initializeAliases(*this);
+ print(os, state, flags);
+}
+void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
+ OperationPrinter(os, flags, state.getImpl()).print(*this);
}
void ModuleOp::dump() { print(llvm::errs()); }
OpenPOWER on IntegriCloud