diff options
| author | River Riddle <riverriddle@google.com> | 2020-01-14 15:23:05 -0800 |
|---|---|---|
| committer | River Riddle <riverriddle@google.com> | 2020-01-14 15:23:31 -0800 |
| commit | fa9dd8336bbd1167926f93fe2018d0c47839d5d6 (patch) | |
| tree | f151e58b0e13a0a551105863f832192c0126a1e9 /mlir/lib/IR | |
| parent | 23058f9dd4d7e18239fd63b6da52549514b45fda (diff) | |
| download | bcm5719-llvm-fa9dd8336bbd1167926f93fe2018d0c47839d5d6.tar.gz bcm5719-llvm-fa9dd8336bbd1167926f93fe2018d0c47839d5d6.zip | |
[mlir] Refactor ModuleState into AsmState and expose it to users.
Summary:
This allows for users to cache printer state, which can be costly to recompute. Each of the IR print methods gain a new overload taking this new state class.
Depends On D72293
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D72294
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 108 |
1 files changed, 69 insertions, 39 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index afb94fea0cd..6f9d76e20f0 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -37,6 +38,7 @@ #include "llvm/Support/Regex.h" #include "llvm/Support/SaveAndRestore.h" using namespace mlir; +using namespace mlir::detail; void Identifier::print(raw_ostream &os) const { os << str(); } @@ -756,13 +758,14 @@ StringRef SSANameState::uniqueValueName(StringRef name) { } //===----------------------------------------------------------------------===// -// ModuleState +// AsmState //===----------------------------------------------------------------------===// -namespace { -class ModuleState { +namespace mlir { +namespace detail { +class AsmStateImpl { public: - explicit ModuleState(Operation *op) + explicit AsmStateImpl(Operation *op) : interfaces(op->getContext()), nameState(op, interfaces) {} /// Initialize the alias state to enable the printing of aliases. @@ -792,7 +795,11 @@ private: /// The state used for SSA value names. SSANameState nameState; }; -} // end anonymous namespace +} // end namespace detail +} // end namespace mlir + +AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {} +AsmState::~AsmState() {} //===----------------------------------------------------------------------===// // ModulePrinter @@ -802,7 +809,7 @@ namespace { class ModulePrinter { public: ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, - ModuleState *state = nullptr) + AsmStateImpl *state = nullptr) : os(os), printerFlags(flags), state(state) {} explicit ModulePrinter(ModulePrinter &printer) : os(printer.os), printerFlags(printer.printerFlags), @@ -816,8 +823,6 @@ public: mlir::interleaveComma(c, os, each_fn); } - void print(ModuleOp module); - /// Print the given attribute. If 'mayElideType' is true, some attributes are /// printed without the type when the type matches the default used in the /// parser (for example i64 is the default for integer attributes). @@ -862,7 +867,7 @@ protected: OpPrintingFlags printerFlags; /// An optional printer state for the module. - ModuleState *state; + AsmStateImpl *state; }; } // end anonymous namespace @@ -1815,10 +1820,12 @@ namespace { /// This class contains the logic for printing operations, regions, and blocks. class OperationPrinter : public ModulePrinter, private OpAsmPrinter { public: - explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) { - assert(state && "expected valid state when printing operation"); - } + explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, + AsmStateImpl &state) + : ModulePrinter(os, flags, &state) {} + /// Print the given top-level module. + void print(ModuleOp op); /// Print the given operation with its indent and location. void print(Operation *op); /// Print the bare location, not including indentation/location/etc. @@ -1903,6 +1910,15 @@ private: }; } // end anonymous namespace +void OperationPrinter::print(ModuleOp op) { + // Output the aliases at the top level. + state->getAliasState().printAttributeAliases(os); + state->getAliasState().printTypeAliases(os); + + // Print the module. + print(op.getOperation()); +} + void OperationPrinter::print(Operation *op) { os.indent(currentIndent); printOperation(op); @@ -2108,18 +2124,6 @@ void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, }); } -void ModulePrinter::print(ModuleOp module) { - assert(state && "expected valid state when printing an operation"); - - // Output the aliases at the top level. - state->getAliasState().printAttributeAliases(os); - state->getAliasState().printTypeAliases(os); - - // Print the module. - OperationPrinter(*this).print(module); - os << '\n'; -} - //===----------------------------------------------------------------------===// // print and dump methods //===----------------------------------------------------------------------===// @@ -2179,18 +2183,34 @@ void Value::print(raw_ostream &os) { assert(isa<BlockArgument>()); os << "<block argument>\n"; } +void Value::print(raw_ostream &os, AsmState &state) { + if (auto *op = getDefiningOp()) + return op->print(os, state); + + // TODO: Improve this. + assert(isa<BlockArgument>()); + os << "<block argument>\n"; +} void Value::dump() { print(llvm::errs()); llvm::errs() << "\n"; } +void Value::printAsOperand(raw_ostream &os, AsmState &state) { + // TODO(riverriddle) This doesn't necessarily capture all potential cases. + // Currently, region arguments can be shadowed when printing the main + // operation. If the IR hasn't been printed, this will produce the old SSA + // name and not the shadowed name. + state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, + os); +} + void Operation::print(raw_ostream &os, OpPrintingFlags flags) { // Handle top-level operations or local printing. if (!getParent() || flags.shouldUseLocalScope()) { - ModuleState state(this); - ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(modulePrinter).print(this); + AsmState state(this); + OperationPrinter(os, flags, state.getImpl()).print(this); return; } @@ -2203,9 +2223,11 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) { while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; - ModuleState state(parentOp); - ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(modulePrinter).print(this); + AsmState state(parentOp); + print(os, state, flags); +} +void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { + OperationPrinter(os, flags, state.getImpl()).print(this); } void Operation::dump() { @@ -2223,9 +2245,11 @@ void Block::print(raw_ostream &os) { while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; - ModuleState state(parentOp); - ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(modulePrinter).print(this); + AsmState state(parentOp); + print(os, state); +} +void Block::print(raw_ostream &os, AsmState &state) { + OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); } void Block::dump() { print(llvm::errs()); } @@ -2241,18 +2265,24 @@ void Block::printAsOperand(raw_ostream &os, bool printType) { while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; - ModuleState state(parentOp); - ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(modulePrinter).printBlockName(this); + AsmState state(parentOp); + printAsOperand(os, state); +} +void Block::printAsOperand(raw_ostream &os, AsmState &state) { + OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); + printer.printBlockName(this); } void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { - ModuleState state(*this); + AsmState state(*this); // Don't populate aliases when printing at local scope. if (!flags.shouldUseLocalScope()) - state.initializeAliases(*this); - ModulePrinter(os, flags, &state).print(*this); + state.getImpl().initializeAliases(*this); + print(os, state, flags); +} +void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { + OperationPrinter(os, flags, state.getImpl()).print(*this); } void ModuleOp::dump() { print(llvm::errs()); } |

