summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-01 14:47:42 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-01 14:48:16 -0700
commit445cc3f6dd74e86575153a95ecfb8754d6d5b726 (patch)
treed7e931817f76f94ce6def37a2704254f4a155c6c /mlir/lib/IR
parentf143fbfa77ffd6a7da030be6009d2ef662d1e3e0 (diff)
downloadbcm5719-llvm-445cc3f6dd74e86575153a95ecfb8754d6d5b726.tar.gz
bcm5719-llvm-445cc3f6dd74e86575153a95ecfb8754d6d5b726.zip
Add DialectAsmParser/Printer classes to simplify dialect attribute and type parsing.
These classes are functionally similar to the OpAsmParser/Printer classes and provide hooks for parsing attributes/tokens/types/etc. This change merely sets up the base infrastructure and updates the parser hooks, followups will add hooks as needed to simplify existing handrolled dialect parsers. This has various different benefits: *) Attribute/Type parsing is much simpler to define. *) Dialect attributes/types that contain other attributes/types can now use aliases. *) It provides a 'spec' with which we may use in the future to auto-generate parsers/printers. *) Error messages emitted by attribute/type parsers can provide character exact locations rather than "beginning of the string" PiperOrigin-RevId: 278005322
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp96
-rw-r--r--mlir/lib/IR/Dialect.cpp9
2 files changed, 78 insertions, 27 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0200e983966..0e6b7882e14 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
@@ -53,6 +54,8 @@ void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
void OperationName::dump() const { print(llvm::errs()); }
+DialectAsmPrinter::~DialectAsmPrinter() {}
+
OpAsmPrinter::~OpAsmPrinter() {}
//===----------------------------------------------------------------------===//
@@ -391,6 +394,9 @@ public:
: os(printer.os), printerFlags(printer.printerFlags),
state(printer.state) {}
+ /// Returns the output stream of the printer.
+ raw_ostream &getStream() { return os; }
+
template <typename Container, typename UnaryFunctor>
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
mlir::interleaveComma(c, os, each_fn);
@@ -420,6 +426,9 @@ protected:
void printLocationInternal(LocationAttr loc, bool pretty = false);
void printDenseElementsAttr(DenseElementsAttr attr);
+ void printDialectAttribute(Attribute attr);
+ void printDialectType(Type type);
+
/// This enum is used to represent the binding strength of the enclosing
/// context that an AffineExprStorage is being printed in, so we can
/// intelligently produce parens.
@@ -715,19 +724,9 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
}
switch (attr.getKind()) {
- default: {
- auto &dialect = attr.getDialect();
-
- // Ask the dialect to serialize the attribute to a string.
- std::string attrName;
- {
- llvm::raw_string_ostream attrNameStr(attrName);
- dialect.printAttribute(attr, attrNameStr);
- }
+ default:
+ return printDialectAttribute(attr);
- printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
- break;
- }
case StandardAttributes::Opaque: {
auto opaqueAttr = attr.cast<OpaqueAttr>();
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
@@ -950,19 +949,9 @@ void ModulePrinter::printType(Type type) {
}
switch (type.getKind()) {
- default: {
- auto &dialect = type.getDialect();
-
- // Ask the dialect to serialize the type to a string.
- std::string typeName;
- {
- llvm::raw_string_ostream typeNameStr(typeName);
- dialect.printType(type, typeNameStr);
- }
+ default:
+ return printDialectType(type);
- printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
- return;
- }
case Type::Kind::Opaque: {
auto opaqueTy = type.cast<OpaqueType>();
printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
@@ -1073,6 +1062,65 @@ void ModulePrinter::printType(Type type) {
}
//===----------------------------------------------------------------------===//
+// CustomDialectAsmPrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides the main specialication of the DialectAsmPrinter that is
+/// used to provide support for print attributes and types. This hooks allows
+/// for dialects to hook into the main ModulePrinter.
+struct CustomDialectAsmPrinter : public DialectAsmPrinter {
+public:
+ CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
+ ~CustomDialectAsmPrinter() override {}
+
+ raw_ostream &getStream() const override { return printer.getStream(); }
+
+ /// Print the given attribute to the stream.
+ void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
+
+ /// Print the given floating point value in a stablized form.
+ void printFloat(const APFloat &value) override {
+ printFloatValue(value, getStream());
+ }
+
+ /// Print the given type to the stream.
+ void printType(Type type) override { printer.printType(type); }
+
+ /// The main module printer.
+ ModulePrinter &printer;
+};
+} // end anonymous namespace
+
+void ModulePrinter::printDialectAttribute(Attribute attr) {
+ auto &dialect = attr.getDialect();
+
+ // Ask the dialect to serialize the attribute to a string.
+ std::string attrName;
+ {
+ llvm::raw_string_ostream attrNameStr(attrName);
+ ModulePrinter subPrinter(attrNameStr, printerFlags, state);
+ CustomDialectAsmPrinter printer(subPrinter);
+ dialect.printAttribute(attr, printer);
+ }
+ printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
+}
+
+void ModulePrinter::printDialectType(Type type) {
+ auto &dialect = type.getDialect();
+
+ // Ask the dialect to serialize the type to a string.
+ std::string typeName;
+ {
+ llvm::raw_string_ostream typeNameStr(typeName);
+ ModulePrinter subPrinter(typeNameStr, printerFlags, state);
+ CustomDialectAsmPrinter printer(subPrinter);
+ dialect.printType(type, printer);
+ }
+ printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
+}
+
+//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index f8539c01d97..7882e4f1f19 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@@ -28,6 +29,8 @@
using namespace mlir;
using namespace detail;
+DialectAsmParser::~DialectAsmParser() {}
+
//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//
@@ -99,7 +102,7 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
}
/// Parse an attribute registered to this dialect.
-Attribute Dialect::parseAttribute(StringRef attrData, Type type,
+Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const {
emitError(loc) << "dialect '" << getNamespace()
<< "' provides no attribute parsing hook";
@@ -107,11 +110,11 @@ Attribute Dialect::parseAttribute(StringRef attrData, Type type,
}
/// Parse a type registered to this dialect.
-Type Dialect::parseType(StringRef tyData, Location loc) const {
+Type Dialect::parseType(DialectAsmParser &parser, Location loc) const {
// If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext());
- return OpaqueType::get(ns, tyData, getContext());
+ return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
}
emitError(loc) << "dialect '" << getNamespace()
OpenPOWER on IntegriCloud