diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-12-27 16:24:33 -0500 |
|---|---|---|
| committer | Lei Zhang <antiagainst@google.com> | 2019-12-27 16:25:09 -0500 |
| commit | b30d87a90ba983d76f8a6cd334ac38244bbf9ded (patch) | |
| tree | 1de94d9458e552ff6c90dadb621f66521659ca1d /mlir/tools | |
| parent | c3dbd782f1e0578c7ebc342f2e92f54d9644cff7 (diff) | |
| download | bcm5719-llvm-b30d87a90ba983d76f8a6cd334ac38244bbf9ded.tar.gz bcm5719-llvm-b30d87a90ba983d76f8a6cd334ac38244bbf9ded.zip | |
[mlir][spirv] Add basic definitions for supporting availability
SPIR-V has a few mechanisms to control op availability: version,
extension, and capabilities. These mechanisms are considered as
different availability classes.
This commit introduces basic definitions for modelling SPIR-V
availability classes. Specifically, an `Availability` class is
added to SPIRVBase.td, along with two subclasses: MinVersion
and MaxVersion for versioning. SPV_Op is extended to take a
list of `Availability`. Each `Availability` instance carries
information for generating op interfaces for the corresponding
availability class and also the concrete availability
requirements.
With the availability spec on ops, we can now auto-generate the
op interfaces of all SPIR-V availability classes and also
synthesize the op's implementations of these interfaces. The
interface generation is done via new TableGen backends
-gen-avail-interface-{decls|defs}. The op's implementation is
done via -gen-spirv-avail-impls.
Differential Revision: https://reviews.llvm.org/D71930
Diffstat (limited to 'mlir/tools')
| -rw-r--r-- | mlir/tools/mlir-opt/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 328 |
2 files changed, 321 insertions, 8 deletions
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index b30d7e39ce8..1281569ef27 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -41,6 +41,7 @@ set(LIBS MLIRROCDLIR MLIRSPIRV MLIRStandardToSPIRVTransforms + MLIRSPIRVTestPasses MLIRSPIRVTransforms MLIRStandardOps MLIRStandardToLLVM diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index d65b216e109..639f01458a6 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -13,6 +13,7 @@ #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/Sequence.h" @@ -44,6 +45,233 @@ using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; //===----------------------------------------------------------------------===// +// Availability Wrapper Class +//===----------------------------------------------------------------------===// + +namespace { +// Wrapper class with helper methods for accessing availability defined in +// TableGen. +class Availability { +public: + explicit Availability(const Record *def); + + // Returns the name of the direct TableGen class for this availability + // instance. + StringRef getClass() const; + + // Returns the generated C++ interface's class name. + StringRef getInterfaceClassName() const; + + // Returns the generated C++ interface's description. + StringRef getInterfaceDescription() const; + + // Returns the name of the query function insided the generated C++ interface. + StringRef getQueryFnName() const; + + // Returns the return type of the query function insided the generated C++ + // interface. + StringRef getQueryFnRetType() const; + + // Returns the code for merging availability requirements. + StringRef getMergeActionCode() const; + + // Returns the initializer expression for initializing the final availability + // requirements. + StringRef getMergeInitializer() const; + + // Returns the C++ type for an availability instance. + StringRef getMergeInstanceType() const; + + // Returns the concrete availability instance carried in this case. + StringRef getMergeInstance() const; + +private: + // The TableGen definition of this availability. + const llvm::Record *def; +}; +} // namespace + +Availability::Availability(const llvm::Record *def) : def(def) { + assert(def->isSubClassOf("Availability") && + "must be subclass of TableGen 'Availability' class"); +} + +StringRef Availability::getClass() const { + SmallVector<Record *, 1> parentClass; + def->getDirectSuperClasses(parentClass); + if (parentClass.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + return parentClass.front()->getName(); +} + +StringRef Availability::getInterfaceClassName() const { + return def->getValueAsString("interfaceName"); +} + +StringRef Availability::getInterfaceDescription() const { + return def->getValueAsString("interfaceDescription"); +} + +StringRef Availability::getQueryFnRetType() const { + return def->getValueAsString("queryFnRetType"); +} + +StringRef Availability::getQueryFnName() const { + return def->getValueAsString("queryFnName"); +} + +StringRef Availability::getMergeActionCode() const { + return def->getValueAsString("mergeAction"); +} + +StringRef Availability::getMergeInitializer() const { + return def->getValueAsString("initializer"); +} + +StringRef Availability::getMergeInstanceType() const { + return def->getValueAsString("instanceType"); +} + +StringRef Availability::getMergeInstance() const { + return def->getValueAsString("instance"); +} + +//===----------------------------------------------------------------------===// +// Availability Interface Definitions AutoGen +//===----------------------------------------------------------------------===// + +static void emitInterfaceDef(const Availability &availability, + raw_ostream &os) { + StringRef methodName = availability.getQueryFnName(); + os << availability.getQueryFnRetType() << " " + << availability.getInterfaceClassName() << "::" << methodName << "() {\n" + << " return getImpl()->" << methodName << "(getOperation());\n" + << "}\n"; +} + +static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Availability Interface Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); + SmallVector<const Record *, 1> handledClasses; + for (const Record *def : defs) { + SmallVector<Record *, 1> parent; + def->getDirectSuperClasses(parent); + if (parent.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + if (llvm::is_contained(handledClasses, parent.front())) + continue; + + Availability availability(def); + emitInterfaceDef(availability, os); + handledClasses.push_back(parent.front()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Availability Interface Declarations AutoGen +//===----------------------------------------------------------------------===// + +static void emitConceptDecl(const Availability &availability, raw_ostream &os) { + os << " class Concept {\n" + << " public:\n" + << " virtual ~Concept() = default;\n" + << " virtual " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n" + << " };\n"; +} + +static void emitModelDecl(const Availability &availability, raw_ostream &os) { + os << " template<typename ConcreteOp>\n"; + os << " class Model : public Concept {\n" + << " public:\n" + << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() + << "(Operation *tblgen_opaque_op) final {\n" + << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n" + << " (void)op;\n" + // Forward to the method on the concrete operation type. + << " return op." << availability.getQueryFnName() << "();\n" + << " }\n" + << " };\n"; +} + +static void emitInterfaceDecl(const Availability &availability, + raw_ostream &os) { + StringRef interfaceName = availability.getInterfaceClassName(); + std::string interfaceTraitsName = formatv("{0}Traits", interfaceName); + + // Emit the traits struct containing the concept and model declarations. + os << "namespace detail {\n" + << "struct " << interfaceTraitsName << " {\n"; + emitConceptDecl(availability, os); + os << '\n'; + emitModelDecl(availability, os); + os << "};\n} // end namespace detail\n\n"; + + // Emit the main interface class declaration. + os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n"; + os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n" + "public:\n" + " using OpInterface<{1}, detail::{2}>::OpInterface;\n", + interfaceName, interfaceName, interfaceTraitsName); + + // Emit query function declaration. + os << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() << "();\n"; + os << "};\n\n"; +} + +static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Availability Interface Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); + SmallVector<const Record *, 4> handledClasses; + for (const Record *def : defs) { + SmallVector<Record *, 1> parent; + def->getDirectSuperClasses(parent); + if (parent.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + if (llvm::is_contained(handledClasses, parent.front())) + continue; + + Availability avail(def); + emitInterfaceDecl(avail, os); + handledClasses.push_back(parent.front()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Availability Interface Hook Registration +//===----------------------------------------------------------------------===// + +// Registers the operation interface generator to mlir-tblgen. +static mlir::GenRegistration + genInterfaceDecls("gen-avail-interface-decls", + "Generate availability interface declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitInterfaceDecls(records, os); + }); + +// Registers the operation interface generator to mlir-tblgen. +static mlir::GenRegistration + genInterfaceDefs("gen-avail-interface-defs", + "Generate op interface definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitInterfaceDefs(records, os); + }); + +//===----------------------------------------------------------------------===// // Serialization AutoGen //===----------------------------------------------------------------------===// @@ -651,6 +879,17 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, } //===----------------------------------------------------------------------===// +// Serialization Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration genSerialization( + "gen-spirv-serialization", + "Generate SPIR-V (de)serialization utilities and functions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitSerializationFns(records, os); + }); + +//===----------------------------------------------------------------------===// // Op Utils AutoGen //===----------------------------------------------------------------------===// @@ -707,19 +946,92 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { } //===----------------------------------------------------------------------===// -// Hook Registration +// Op Utils Hook Registration //===----------------------------------------------------------------------===// -static mlir::GenRegistration genSerialization( - "gen-spirv-serialization", - "Generate SPIR-V (de)serialization utilities and functions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitSerializationFns(records, os); - }); - static mlir::GenRegistration genOpUtils("gen-spirv-op-utils", "Generate SPIR-V operation utility definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitOpUtils(records, os); }); + +//===----------------------------------------------------------------------===// +// SPIR-V Availability Impl AutoGen +//===----------------------------------------------------------------------===// + +// Returns the availability spec of the given `def`. +std::vector<Availability> getAvailabilities(const Record &def) { + std::vector<Availability> availabilities; + if (auto *availListInit = def.getValueAsListInit("availability")) { + availabilities.reserve(availListInit->size()); + for (auto *availInit : *availListInit) + availabilities.emplace_back( + llvm::cast<llvm::DefInit>(availInit)->getDef()); + } + return availabilities; +} + +static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { + mlir::tblgen::FmtContext fctx; + fctx.addSubst("overall", "overall"); + + std::vector<Availability> opAvailabilities = + getAvailabilities(srcOp.getDef()); + + // First collect all availablity classes this op should implement. + // All availablity instances keep information for the generated interface and + // the instance's specific requirement. Here we remember a random instance so + // we can get the information regarding the generated interface. + llvm::StringMap<Availability> availClasses; + for (const Availability &avail : opAvailabilities) + availClasses.try_emplace(avail.getClass(), avail); + + // Then generate implementation for each availability class. + for (const auto &availClass : availClasses) { + StringRef availClassName = availClass.getKey(); + Availability avail = availClass.getValue(); + + // Generate the implementation method signature. + os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(), + srcOp.getCppClassName(), avail.getQueryFnName()); + + // Create the variable for the final requirement and initialize it. + os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(), + avail.getMergeInitializer()); + + // Update with the op's specific availability spec. + for (const Availability &avail : opAvailabilities) + if (avail.getClass() == availClassName) { + os << " " + << tgfmt(avail.getMergeActionCode(), + &fctx.addSubst("instance", avail.getMergeInstance())) + << ";\n"; + } + os << " return overall;\n"; + os << "}\n"; + } +} + +static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); + for (const auto *def : defs) { + Operator op(def); + emitAvailabilityImpl(op, os); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Op Availability Implementation Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genOpAvailabilityImpl("gen-spirv-avail-impls", + "Generate SPIR-V operation utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitAvailabilityImpl(records, os); + }); |

