diff options
Diffstat (limited to 'mlir/lib/IR/Dialect.cpp')
-rw-r--r-- | mlir/lib/IR/Dialect.cpp | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp new file mode 100644 index 00000000000..b2485a368fd --- /dev/null +++ b/mlir/lib/IR/Dialect.cpp @@ -0,0 +1,156 @@ +//===- Dialect.cpp - Dialect implementation -------------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectHooks.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/Regex.h" + +using namespace mlir; +using namespace detail; + +DialectAsmParser::~DialectAsmParser() {} + +//===----------------------------------------------------------------------===// +// Dialect Registration +//===----------------------------------------------------------------------===// + +// Registry for all dialect allocation functions. +static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>> + dialectRegistry; + +// Registry for functions that set dialect hooks. +static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>> + dialectHooksRegistry; + +/// Registers a specific dialect creation function with the system, typically +/// used through the DialectRegistration template. +void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) { + assert(function && + "Attempting to register an empty dialect initialize function"); + dialectRegistry->push_back(function); +} + +/// Registers a function to set specific hooks for a specific dialect, typically +/// used through the DialectHooksRegistration template. +void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) { + assert( + function && + "Attempting to register an empty dialect hooks initialization function"); + + dialectHooksRegistry->push_back(function); +} + +/// Registers all dialects and their const folding hooks with the specified +/// MLIRContext. +void mlir::registerAllDialects(MLIRContext *context) { + for (const auto &fn : *dialectRegistry) + fn(context); + for (const auto &fn : *dialectHooksRegistry) { + fn(context); + } +} + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +Dialect::Dialect(StringRef name, MLIRContext *context) + : name(name), context(context) { + assert(isValidNamespace(name) && "invalid dialect namespace"); + registerDialect(context); +} + +Dialect::~Dialect() {} + +/// Verify an attribute from this dialect on the argument at 'argIndex' for +/// the region at 'regionIndex' on the given operation. Returns failure if +/// the verification failed, success otherwise. This hook may optionally be +/// invoked from any operation containing a region. +LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, + NamedAttribute) { + return success(); +} + +/// Verify an attribute from this dialect on the result at 'resultIndex' for +/// the region at 'regionIndex' on the given operation. Returns failure if +/// the verification failed, success otherwise. This hook may optionally be +/// invoked from any operation containing a region. +LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, + unsigned, NamedAttribute) { + return success(); +} + +/// Parse an attribute registered to this dialect. +Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { + parser.emitError(parser.getNameLoc()) + << "dialect '" << getNamespace() + << "' provides no attribute parsing hook"; + return Attribute(); +} + +/// Parse a type registered to this dialect. +Type Dialect::parseType(DialectAsmParser &parser) const { + // If this dialect allows unknown types, then represent this with OpaqueType. + if (allowsUnknownTypes()) { + auto ns = Identifier::get(getNamespace(), getContext()); + return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); + } + + parser.emitError(parser.getNameLoc()) + << "dialect '" << getNamespace() << "' provides no type parsing hook"; + return Type(); +} + +/// Utility function that returns if the given string is a valid dialect +/// namespace. +bool Dialect::isValidNamespace(StringRef str) { + if (str.empty()) + return true; + llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); + return dialectNameRegex.match(str); +} + +/// Register a set of dialect interfaces with this dialect instance. +void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { + auto it = registeredInterfaces.try_emplace(interface->getID(), + std::move(interface)); + (void)it; + assert(it.second && "interface kind has already been registered"); +} + +//===----------------------------------------------------------------------===// +// Dialect Interface +//===----------------------------------------------------------------------===// + +DialectInterface::~DialectInterface() {} + +DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( + MLIRContext *ctx, ClassID *interfaceKind) { + for (auto *dialect : ctx->getRegisteredDialects()) { + if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { + interfaces.insert(interface); + orderedInterfaces.push_back(interface); + } + } +} + +DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} + +/// Get the interface for the dialect of given operation, or null if one +/// is not registered. +const DialectInterface * +DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { + return getInterfaceFor(op->getDialect()); +} |