summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/TableGen/ODSDialectHook.h42
-rw-r--r--mlir/include/mlir/TableGen/Operator.h9
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp37
3 files changed, 81 insertions, 7 deletions
diff --git a/mlir/include/mlir/TableGen/ODSDialectHook.h b/mlir/include/mlir/TableGen/ODSDialectHook.h
new file mode 100644
index 00000000000..9d1ea3b6857
--- /dev/null
+++ b/mlir/include/mlir/TableGen/ODSDialectHook.h
@@ -0,0 +1,42 @@
+//===- ODSDialectHook.h - Dialect customization hooks into ODS --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ODS customization hooks for dialects to programmatically
+// emit dialect specific contents in ODS C++ code emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ODSDIALECTHOOK_H_
+#define MLIR_TABLEGEN_ODSDIALECTHOOK_H_
+
+#include <functional>
+
+namespace llvm {
+class StringRef;
+}
+
+namespace mlir {
+namespace tblgen {
+class Operator;
+class OpClass;
+
+// The emission function for dialect specific content. It takes in an Operator
+// and updates the OpClass accordingly.
+using DialectEmitFunction =
+ std::function<void(const Operator &srcOp, OpClass &emitClass)>;
+
+// ODSDialectHookRegistration provides a global initializer that registers a
+// dialect specific content emission function.
+struct ODSDialectHookRegistration {
+ ODSDialectHookRegistration(llvm::StringRef dialectName,
+ DialectEmitFunction emitFn);
+};
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_ODSDIALECTHOOK_H_
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index dd5ff353bf9..fc558011fe9 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -46,6 +46,9 @@ public:
// Returns this op's dialect name.
StringRef getDialectName() const;
+ // Returns the dialect of the op.
+ const Dialect &getDialect() const { return dialect; }
+
// Returns the operation name. The name will follow the "<dialect>.<op-name>"
// format if its dialect name is not empty.
std::string getOperationName() const;
@@ -156,14 +159,8 @@ public:
StringRef getExtraClassDeclaration() const;
// Returns the Tablegen definition this operator was constructed from.
- // TODO(antiagainst,zinenko): do not expose the TableGen record, this is a
- // temporary solution to OpEmitter requiring a Record because Operator does
- // not provide enough methods.
const llvm::Record &getDef() const;
- // Returns the dialect of the op.
- const Dialect &getDialect() const { return dialect; }
-
// Prints the contents in this operator to the given `os`. This is used for
// debugging purposes.
void print(llvm::raw_ostream &os) const;
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 4cb5b059705..c22aff17e53 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -14,11 +14,13 @@
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/ODSDialectHook.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -26,10 +28,35 @@
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
-using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
+using llvm::CodeInit;
+using llvm::DefInit;
+using llvm::formatv;
+using llvm::Init;
+using llvm::ListInit;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringInit;
+
+//===----------------------------------------------------------------------===//
+// Dialect hook registration
+//===----------------------------------------------------------------------===//
+
+static llvm::ManagedStatic<llvm::StringMap<DialectEmitFunction>> dialectHooks;
+
+ODSDialectHookRegistration::ODSDialectHookRegistration(
+ StringRef dialectName, DialectEmitFunction emitFn) {
+ bool inserted = dialectHooks->try_emplace(dialectName, emitFn).second;
+ assert(inserted && "Multiple ODS hooks for the same dialect!");
+ (void)inserted;
+}
+
+//===----------------------------------------------------------------------===//
+// Static string definitions
+//===----------------------------------------------------------------------===//
+
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "tblgen_arg";
static const char *const builderOpState = "tblgen_state";
@@ -279,6 +306,7 @@ OpEmitter::OpEmitter(const Operator &op)
verifyCtx.withOp("(*this->getOperation())");
genTraits();
+
// Generate C++ code for various op methods. The order here determines the
// methods in the generated file.
genOpAsmInterface();
@@ -294,6 +322,13 @@ OpEmitter::OpEmitter(const Operator &op)
genCanonicalizerDecls();
genFolderDecls();
genOpInterfaceMethods();
+
+ // If a dialect hook is registered for this op's dialect, emit dialect
+ // specific content.
+ auto dialectHookIt = dialectHooks->find(op.getDialectName());
+ if (dialectHookIt != dialectHooks->end()) {
+ dialectHookIt->second(op, opClass);
+ }
}
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpenPOWER on IntegriCloud