diff options
| author | River Riddle <riverriddle@google.com> | 2019-11-01 14:47:42 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-01 14:48:16 -0700 |
| commit | 445cc3f6dd74e86575153a95ecfb8754d6d5b726 (patch) | |
| tree | d7e931817f76f94ce6def37a2704254f4a155c6c /mlir/lib/IR | |
| parent | f143fbfa77ffd6a7da030be6009d2ef662d1e3e0 (diff) | |
| download | bcm5719-llvm-445cc3f6dd74e86575153a95ecfb8754d6d5b726.tar.gz bcm5719-llvm-445cc3f6dd74e86575153a95ecfb8754d6d5b726.zip | |
Add DialectAsmParser/Printer classes to simplify dialect attribute and type parsing.
These classes are functionally similar to the OpAsmParser/Printer classes and provide hooks for parsing attributes/tokens/types/etc. This change merely sets up the base infrastructure and updates the parser hooks, followups will add hooks as needed to simplify existing handrolled dialect parsers.
This has various different benefits:
*) Attribute/Type parsing is much simpler to define.
*) Dialect attributes/types that contain other attributes/types can now use aliases.
*) It provides a 'spec' with which we may use in the future to auto-generate parsers/printers.
*) Error messages emitted by attribute/type parsers can provide character exact locations rather than "beginning of the string"
PiperOrigin-RevId: 278005322
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 96 | ||||
| -rw-r--r-- | mlir/lib/IR/Dialect.cpp | 9 |
2 files changed, 78 insertions, 27 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 0200e983966..0e6b7882e14 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" @@ -53,6 +54,8 @@ void OperationName::print(raw_ostream &os) const { os << getStringRef(); } void OperationName::dump() const { print(llvm::errs()); } +DialectAsmPrinter::~DialectAsmPrinter() {} + OpAsmPrinter::~OpAsmPrinter() {} //===----------------------------------------------------------------------===// @@ -391,6 +394,9 @@ public: : os(printer.os), printerFlags(printer.printerFlags), state(printer.state) {} + /// Returns the output stream of the printer. + raw_ostream &getStream() { return os; } + template <typename Container, typename UnaryFunctor> inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { mlir::interleaveComma(c, os, each_fn); @@ -420,6 +426,9 @@ protected: void printLocationInternal(LocationAttr loc, bool pretty = false); void printDenseElementsAttr(DenseElementsAttr attr); + void printDialectAttribute(Attribute attr); + void printDialectType(Type type); + /// This enum is used to represent the binding strength of the enclosing /// context that an AffineExprStorage is being printed in, so we can /// intelligently produce parens. @@ -715,19 +724,9 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { } switch (attr.getKind()) { - default: { - auto &dialect = attr.getDialect(); - - // Ask the dialect to serialize the attribute to a string. - std::string attrName; - { - llvm::raw_string_ostream attrNameStr(attrName); - dialect.printAttribute(attr, attrNameStr); - } + default: + return printDialectAttribute(attr); - printDialectSymbol(os, "#", dialect.getNamespace(), attrName); - break; - } case StandardAttributes::Opaque: { auto opaqueAttr = attr.cast<OpaqueAttr>(); printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), @@ -950,19 +949,9 @@ void ModulePrinter::printType(Type type) { } switch (type.getKind()) { - default: { - auto &dialect = type.getDialect(); - - // Ask the dialect to serialize the type to a string. - std::string typeName; - { - llvm::raw_string_ostream typeNameStr(typeName); - dialect.printType(type, typeNameStr); - } + default: + return printDialectType(type); - printDialectSymbol(os, "!", dialect.getNamespace(), typeName); - return; - } case Type::Kind::Opaque: { auto opaqueTy = type.cast<OpaqueType>(); printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), @@ -1073,6 +1062,65 @@ void ModulePrinter::printType(Type type) { } //===----------------------------------------------------------------------===// +// CustomDialectAsmPrinter +//===----------------------------------------------------------------------===// + +namespace { +/// This class provides the main specialication of the DialectAsmPrinter that is +/// used to provide support for print attributes and types. This hooks allows +/// for dialects to hook into the main ModulePrinter. +struct CustomDialectAsmPrinter : public DialectAsmPrinter { +public: + CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} + ~CustomDialectAsmPrinter() override {} + + raw_ostream &getStream() const override { return printer.getStream(); } + + /// Print the given attribute to the stream. + void printAttribute(Attribute attr) override { printer.printAttribute(attr); } + + /// Print the given floating point value in a stablized form. + void printFloat(const APFloat &value) override { + printFloatValue(value, getStream()); + } + + /// Print the given type to the stream. + void printType(Type type) override { printer.printType(type); } + + /// The main module printer. + ModulePrinter &printer; +}; +} // end anonymous namespace + +void ModulePrinter::printDialectAttribute(Attribute attr) { + auto &dialect = attr.getDialect(); + + // Ask the dialect to serialize the attribute to a string. + std::string attrName; + { + llvm::raw_string_ostream attrNameStr(attrName); + ModulePrinter subPrinter(attrNameStr, printerFlags, state); + CustomDialectAsmPrinter printer(subPrinter); + dialect.printAttribute(attr, printer); + } + printDialectSymbol(os, "#", dialect.getNamespace(), attrName); +} + +void ModulePrinter::printDialectType(Type type) { + auto &dialect = type.getDialect(); + + // Ask the dialect to serialize the type to a string. + std::string typeName; + { + llvm::raw_string_ostream typeNameStr(typeName); + ModulePrinter subPrinter(typeNameStr, printerFlags, state); + CustomDialectAsmPrinter printer(subPrinter); + dialect.printType(type, printer); + } + printDialectSymbol(os, "!", dialect.getNamespace(), typeName); +} + +//===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index f8539c01d97..7882e4f1f19 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectHooks.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -28,6 +29,8 @@ using namespace mlir; using namespace detail; +DialectAsmParser::~DialectAsmParser() {} + //===----------------------------------------------------------------------===// // Dialect Registration //===----------------------------------------------------------------------===// @@ -99,7 +102,7 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, } /// Parse an attribute registered to this dialect. -Attribute Dialect::parseAttribute(StringRef attrData, Type type, +Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type, Location loc) const { emitError(loc) << "dialect '" << getNamespace() << "' provides no attribute parsing hook"; @@ -107,11 +110,11 @@ Attribute Dialect::parseAttribute(StringRef attrData, Type type, } /// Parse a type registered to this dialect. -Type Dialect::parseType(StringRef tyData, Location loc) const { +Type Dialect::parseType(DialectAsmParser &parser, Location loc) const { // If this dialect allows unknown types, then represent this with OpaqueType. if (allowsUnknownTypes()) { auto ns = Identifier::get(getNamespace(), getContext()); - return OpaqueType::get(ns, tyData, getContext()); + return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); } emitError(loc) << "dialect '" << getNamespace() |

