diff options
-rw-r--r-- | mlir/include/mlir/TableGen/ODSDialectHook.h | 42 | ||||
-rw-r--r-- | mlir/include/mlir/TableGen/Operator.h | 9 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 37 |
3 files changed, 81 insertions, 7 deletions
diff --git a/mlir/include/mlir/TableGen/ODSDialectHook.h b/mlir/include/mlir/TableGen/ODSDialectHook.h new file mode 100644 index 00000000000..9d1ea3b6857 --- /dev/null +++ b/mlir/include/mlir/TableGen/ODSDialectHook.h @@ -0,0 +1,42 @@ +//===- ODSDialectHook.h - Dialect customization hooks into ODS --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines ODS customization hooks for dialects to programmatically +// emit dialect specific contents in ODS C++ code emission. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_ODSDIALECTHOOK_H_ +#define MLIR_TABLEGEN_ODSDIALECTHOOK_H_ + +#include <functional> + +namespace llvm { +class StringRef; +} + +namespace mlir { +namespace tblgen { +class Operator; +class OpClass; + +// The emission function for dialect specific content. It takes in an Operator +// and updates the OpClass accordingly. +using DialectEmitFunction = + std::function<void(const Operator &srcOp, OpClass &emitClass)>; + +// ODSDialectHookRegistration provides a global initializer that registers a +// dialect specific content emission function. +struct ODSDialectHookRegistration { + ODSDialectHookRegistration(llvm::StringRef dialectName, + DialectEmitFunction emitFn); +}; +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TABLEGEN_ODSDIALECTHOOK_H_ diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index dd5ff353bf9..fc558011fe9 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -46,6 +46,9 @@ public: // Returns this op's dialect name. StringRef getDialectName() const; + // Returns the dialect of the op. + const Dialect &getDialect() const { return dialect; } + // Returns the operation name. The name will follow the "<dialect>.<op-name>" // format if its dialect name is not empty. std::string getOperationName() const; @@ -156,14 +159,8 @@ public: StringRef getExtraClassDeclaration() const; // Returns the Tablegen definition this operator was constructed from. - // TODO(antiagainst,zinenko): do not expose the TableGen record, this is a - // temporary solution to OpEmitter requiring a Record because Operator does - // not provide enough methods. const llvm::Record &getDef() const; - // Returns the dialect of the op. - const Dialect &getDialect() const { return dialect; } - // Prints the contents in this operator to the given `os`. This is used for // debugging purposes. void print(llvm::raw_ostream &os) const; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 4cb5b059705..c22aff17e53 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -14,11 +14,13 @@ #include "mlir/Support/STLExtras.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/ODSDialectHook.h" #include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -26,10 +28,35 @@ #define DEBUG_TYPE "mlir-tblgen-opdefgen" -using namespace llvm; using namespace mlir; using namespace mlir::tblgen; +using llvm::CodeInit; +using llvm::DefInit; +using llvm::formatv; +using llvm::Init; +using llvm::ListInit; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::StringInit; + +//===----------------------------------------------------------------------===// +// Dialect hook registration +//===----------------------------------------------------------------------===// + +static llvm::ManagedStatic<llvm::StringMap<DialectEmitFunction>> dialectHooks; + +ODSDialectHookRegistration::ODSDialectHookRegistration( + StringRef dialectName, DialectEmitFunction emitFn) { + bool inserted = dialectHooks->try_emplace(dialectName, emitFn).second; + assert(inserted && "Multiple ODS hooks for the same dialect!"); + (void)inserted; +} + +//===----------------------------------------------------------------------===// +// Static string definitions +//===----------------------------------------------------------------------===// + static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "tblgen_arg"; static const char *const builderOpState = "tblgen_state"; @@ -279,6 +306,7 @@ OpEmitter::OpEmitter(const Operator &op) verifyCtx.withOp("(*this->getOperation())"); genTraits(); + // Generate C++ code for various op methods. The order here determines the // methods in the generated file. genOpAsmInterface(); @@ -294,6 +322,13 @@ OpEmitter::OpEmitter(const Operator &op) genCanonicalizerDecls(); genFolderDecls(); genOpInterfaceMethods(); + + // If a dialect hook is registered for this op's dialect, emit dialect + // specific content. + auto dialectHookIt = dialectHooks->find(op.getDialectName()); + if (dialectHookIt != dialectHooks->end()) { + dialectHookIt->second(op, opClass); + } } void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { |