summaryrefslogtreecommitdiffstats
path: root/mlir/tools/mlir-tblgen/StructsGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/StructsGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/StructsGen.cpp250
1 files changed, 250 insertions, 0 deletions
diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
new file mode 100644
index 00000000000..576085e41eb
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -0,0 +1,250 @@
+//===- StructsGen.cpp - MLIR struct utility generator ---------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// StructsGen generates common utility functions for grouping attributes into a
+// set of structured data.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using llvm::raw_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringRef;
+using mlir::tblgen::StructAttr;
+
+static void
+emitStructClass(const Record &structDef, StringRef structName,
+ llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
+ StringRef description, raw_ostream &os) {
+ const char *structInfo = R"(
+// {0}
+class {1} : public mlir::DictionaryAttr)";
+ const char *structInfoEnd = R"( {
+public:
+ using DictionaryAttr::DictionaryAttr;
+ static bool classof(mlir::Attribute attr);
+)";
+ os << formatv(structInfo, description, structName) << structInfoEnd;
+
+ // Declares a constructor function for the tablegen structure.
+ // TblgenStruct::get(MLIRContext context, Type1 Field1, Type2 Field2, ...);
+ const char *getInfoDecl = " static {0} get(\n";
+ const char *getInfoDeclArg = " {0} {1},\n";
+ const char *getInfoDeclEnd = " mlir::MLIRContext* context);\n\n";
+
+ os << llvm::formatv(getInfoDecl, structName);
+
+ for (auto field : fields) {
+ auto name = field.getName();
+ auto type = field.getType();
+ auto storage = type.getStorageType();
+ os << llvm::formatv(getInfoDeclArg, storage, name);
+ }
+ os << getInfoDeclEnd;
+
+ // Declares an accessor for the fields owned by the tablegen structure.
+ // namespace::storage TblgenStruct::field1() const;
+ const char *fieldInfo = R"( {0} {1}() const;
+)";
+ for (const auto field : fields) {
+ auto name = field.getName();
+ auto type = field.getType();
+ auto storage = type.getStorageType();
+ os << formatv(fieldInfo, storage, name);
+ }
+
+ os << "};\n\n";
+}
+
+static void emitStructDecl(const Record &structDef, raw_ostream &os) {
+ StructAttr structAttr(&structDef);
+ StringRef structName = structAttr.getStructClassName();
+ StringRef cppNamespace = structAttr.getCppNamespace();
+ StringRef description = structAttr.getDescription();
+ auto fields = structAttr.getAllFields();
+
+ // Wrap in the appropriate namespace.
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(cppNamespace, namespaces, "::");
+
+ for (auto ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ // Emit the struct class definition
+ emitStructClass(structDef, structName, fields, description, os);
+
+ // Close the declared namespace.
+ for (auto ns : namespaces)
+ os << "} // namespace " << ns << "\n";
+}
+
+static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ llvm::emitSourceFileHeader("Struct Utility Declarations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
+ for (const auto *def : defs) {
+ emitStructDecl(*def, os);
+ }
+
+ return false;
+}
+
+static void emitFactoryDef(llvm::StringRef structName,
+ llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
+ raw_ostream &os) {
+ const char *getInfoDecl = "{0} {0}::get(\n";
+ const char *getInfoDeclArg = " {0} {1},\n";
+ const char *getInfoDeclEnd = " mlir::MLIRContext* context) {";
+
+ os << llvm::formatv(getInfoDecl, structName);
+
+ for (auto field : fields) {
+ auto name = field.getName();
+ auto type = field.getType();
+ auto storage = type.getStorageType();
+ os << llvm::formatv(getInfoDeclArg, storage, name);
+ }
+ os << getInfoDeclEnd;
+
+ const char *fieldStart = R"(
+ llvm::SmallVector<mlir::NamedAttribute, {0}> fields;
+)";
+ os << llvm::formatv(fieldStart, fields.size());
+
+ const char *getFieldInfo = R"(
+ assert({0});
+ auto {0}_id = mlir::Identifier::get("{0}", context);
+ fields.emplace_back({0}_id, {0});
+)";
+
+ for (auto field : fields) {
+ os << llvm::formatv(getFieldInfo, field.getName());
+ }
+
+ const char *getEndInfo = R"(
+ Attribute dict = mlir::DictionaryAttr::get(fields, context);
+ return dict.dyn_cast<{0}>();
+}
+)";
+ os << llvm::formatv(getEndInfo, structName);
+}
+
+static void emitClassofDef(llvm::StringRef structName,
+ llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
+ raw_ostream &os) {
+ const char *classofInfo = R"(
+bool {0}::classof(mlir::Attribute attr))";
+
+ const char *classofInfoHeader = R"(
+ auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
+ if (!derived)
+ return false;
+ if (derived.size() != {0})
+ return false;
+)";
+
+ os << llvm::formatv(classofInfo, structName) << " {";
+ os << llvm::formatv(classofInfoHeader, fields.size());
+
+ const char *classofArgInfo = R"(
+ auto {0} = derived.get("{0}");
+ if (!{0} || !{0}.isa<{1}>())
+ return false;
+)";
+ for (auto field : fields) {
+ auto name = field.getName();
+ auto type = field.getType();
+ auto storage = type.getStorageType();
+ os << llvm::formatv(classofArgInfo, name, storage);
+ }
+
+ const char *classofEndInfo = R"(
+ return true;
+}
+)";
+ os << classofEndInfo;
+}
+
+static void
+emitAccessorDef(llvm::StringRef structName,
+ llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
+ raw_ostream &os) {
+ const char *fieldInfo = R"(
+{0} {2}::{1}() const {
+ auto derived = this->cast<mlir::DictionaryAttr>();
+ auto {1} = derived.get("{1}");
+ assert({1} && "attribute not found.");
+ assert({1}.isa<{0}>() && "incorrect Attribute type found.");
+ return {1}.cast<{0}>();
+}
+)";
+ for (auto field : fields) {
+ auto name = field.getName();
+ auto type = field.getType();
+ auto storage = type.getStorageType();
+ os << llvm::formatv(fieldInfo, storage, name, structName);
+ }
+}
+
+static void emitStructDef(const Record &structDef, raw_ostream &os) {
+ StructAttr structAttr(&structDef);
+ StringRef cppNamespace = structAttr.getCppNamespace();
+ StringRef structName = structAttr.getStructClassName();
+ mlir::tblgen::FmtContext ctx;
+ auto fields = structAttr.getAllFields();
+
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(cppNamespace, namespaces, "::");
+
+ for (auto ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ emitFactoryDef(structName, fields, os);
+ emitClassofDef(structName, fields, os);
+ emitAccessorDef(structName, fields, os);
+
+ for (auto ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
+}
+
+static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ llvm::emitSourceFileHeader("Struct Utility Definitions", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
+ for (const auto *def : defs)
+ emitStructDef(*def, os);
+
+ return false;
+}
+
+// Registers the struct utility generator to mlir-tblgen.
+static mlir::GenRegistration
+ genStructDecls("gen-struct-attr-decls",
+ "Generate struct utility declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitStructDecls(records, os);
+ });
+
+// Registers the struct utility generator to mlir-tblgen.
+static mlir::GenRegistration
+ genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitStructDefs(records, os);
+ });
OpenPOWER on IntegriCloud