summaryrefslogtreecommitdiffstats
path: root/mlir/tools
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-12-27 16:24:33 -0500
committerLei Zhang <antiagainst@google.com>2019-12-27 16:25:09 -0500
commitb30d87a90ba983d76f8a6cd334ac38244bbf9ded (patch)
tree1de94d9458e552ff6c90dadb621f66521659ca1d /mlir/tools
parentc3dbd782f1e0578c7ebc342f2e92f54d9644cff7 (diff)
downloadbcm5719-llvm-b30d87a90ba983d76f8a6cd334ac38244bbf9ded.tar.gz
bcm5719-llvm-b30d87a90ba983d76f8a6cd334ac38244bbf9ded.zip
[mlir][spirv] Add basic definitions for supporting availability
SPIR-V has a few mechanisms to control op availability: version, extension, and capabilities. These mechanisms are considered as different availability classes. This commit introduces basic definitions for modelling SPIR-V availability classes. Specifically, an `Availability` class is added to SPIRVBase.td, along with two subclasses: MinVersion and MaxVersion for versioning. SPV_Op is extended to take a list of `Availability`. Each `Availability` instance carries information for generating op interfaces for the corresponding availability class and also the concrete availability requirements. With the availability spec on ops, we can now auto-generate the op interfaces of all SPIR-V availability classes and also synthesize the op's implementations of these interfaces. The interface generation is done via new TableGen backends -gen-avail-interface-{decls|defs}. The op's implementation is done via -gen-spirv-avail-impls. Differential Revision: https://reviews.llvm.org/D71930
Diffstat (limited to 'mlir/tools')
-rw-r--r--mlir/tools/mlir-opt/CMakeLists.txt1
-rw-r--r--mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp328
2 files changed, 321 insertions, 8 deletions
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b30d7e39ce8..1281569ef27 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -41,6 +41,7 @@ set(LIBS
MLIRROCDLIR
MLIRSPIRV
MLIRStandardToSPIRVTransforms
+ MLIRSPIRVTestPasses
MLIRSPIRVTransforms
MLIRStandardOps
MLIRStandardToLLVM
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index d65b216e109..639f01458a6 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -13,6 +13,7 @@
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
@@ -44,6 +45,233 @@ using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
+// Availability Wrapper Class
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Wrapper class with helper methods for accessing availability defined in
+// TableGen.
+class Availability {
+public:
+ explicit Availability(const Record *def);
+
+ // Returns the name of the direct TableGen class for this availability
+ // instance.
+ StringRef getClass() const;
+
+ // Returns the generated C++ interface's class name.
+ StringRef getInterfaceClassName() const;
+
+ // Returns the generated C++ interface's description.
+ StringRef getInterfaceDescription() const;
+
+ // Returns the name of the query function insided the generated C++ interface.
+ StringRef getQueryFnName() const;
+
+ // Returns the return type of the query function insided the generated C++
+ // interface.
+ StringRef getQueryFnRetType() const;
+
+ // Returns the code for merging availability requirements.
+ StringRef getMergeActionCode() const;
+
+ // Returns the initializer expression for initializing the final availability
+ // requirements.
+ StringRef getMergeInitializer() const;
+
+ // Returns the C++ type for an availability instance.
+ StringRef getMergeInstanceType() const;
+
+ // Returns the concrete availability instance carried in this case.
+ StringRef getMergeInstance() const;
+
+private:
+ // The TableGen definition of this availability.
+ const llvm::Record *def;
+};
+} // namespace
+
+Availability::Availability(const llvm::Record *def) : def(def) {
+ assert(def->isSubClassOf("Availability") &&
+ "must be subclass of TableGen 'Availability' class");
+}
+
+StringRef Availability::getClass() const {
+ SmallVector<Record *, 1> parentClass;
+ def->getDirectSuperClasses(parentClass);
+ if (parentClass.size() != 1) {
+ PrintFatalError(def->getLoc(),
+ "expected to only have one direct superclass");
+ }
+ return parentClass.front()->getName();
+}
+
+StringRef Availability::getInterfaceClassName() const {
+ return def->getValueAsString("interfaceName");
+}
+
+StringRef Availability::getInterfaceDescription() const {
+ return def->getValueAsString("interfaceDescription");
+}
+
+StringRef Availability::getQueryFnRetType() const {
+ return def->getValueAsString("queryFnRetType");
+}
+
+StringRef Availability::getQueryFnName() const {
+ return def->getValueAsString("queryFnName");
+}
+
+StringRef Availability::getMergeActionCode() const {
+ return def->getValueAsString("mergeAction");
+}
+
+StringRef Availability::getMergeInitializer() const {
+ return def->getValueAsString("initializer");
+}
+
+StringRef Availability::getMergeInstanceType() const {
+ return def->getValueAsString("instanceType");
+}
+
+StringRef Availability::getMergeInstance() const {
+ return def->getValueAsString("instance");
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Definitions AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitInterfaceDef(const Availability &availability,
+ raw_ostream &os) {
+ StringRef methodName = availability.getQueryFnName();
+ os << availability.getQueryFnRetType() << " "
+ << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
+ << " return getImpl()->" << methodName << "(getOperation());\n"
+ << "}\n";
+}
+
+static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ llvm::emitSourceFileHeader("Availability Interface Definitions", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
+ SmallVector<const Record *, 1> handledClasses;
+ for (const Record *def : defs) {
+ SmallVector<Record *, 1> parent;
+ def->getDirectSuperClasses(parent);
+ if (parent.size() != 1) {
+ PrintFatalError(def->getLoc(),
+ "expected to only have one direct superclass");
+ }
+ if (llvm::is_contained(handledClasses, parent.front()))
+ continue;
+
+ Availability availability(def);
+ emitInterfaceDef(availability, os);
+ handledClasses.push_back(parent.front());
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Declarations AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
+ os << " class Concept {\n"
+ << " public:\n"
+ << " virtual ~Concept() = default;\n"
+ << " virtual " << availability.getQueryFnRetType() << " "
+ << availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n"
+ << " };\n";
+}
+
+static void emitModelDecl(const Availability &availability, raw_ostream &os) {
+ os << " template<typename ConcreteOp>\n";
+ os << " class Model : public Concept {\n"
+ << " public:\n"
+ << " " << availability.getQueryFnRetType() << " "
+ << availability.getQueryFnName()
+ << "(Operation *tblgen_opaque_op) final {\n"
+ << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
+ << " (void)op;\n"
+ // Forward to the method on the concrete operation type.
+ << " return op." << availability.getQueryFnName() << "();\n"
+ << " }\n"
+ << " };\n";
+}
+
+static void emitInterfaceDecl(const Availability &availability,
+ raw_ostream &os) {
+ StringRef interfaceName = availability.getInterfaceClassName();
+ std::string interfaceTraitsName = formatv("{0}Traits", interfaceName);
+
+ // Emit the traits struct containing the concept and model declarations.
+ os << "namespace detail {\n"
+ << "struct " << interfaceTraitsName << " {\n";
+ emitConceptDecl(availability, os);
+ os << '\n';
+ emitModelDecl(availability, os);
+ os << "};\n} // end namespace detail\n\n";
+
+ // Emit the main interface class declaration.
+ os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
+ os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
+ "public:\n"
+ " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
+ interfaceName, interfaceName, interfaceTraitsName);
+
+ // Emit query function declaration.
+ os << " " << availability.getQueryFnRetType() << " "
+ << availability.getQueryFnName() << "();\n";
+ os << "};\n\n";
+}
+
+static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ llvm::emitSourceFileHeader("Availability Interface Declarations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
+ SmallVector<const Record *, 4> handledClasses;
+ for (const Record *def : defs) {
+ SmallVector<Record *, 1> parent;
+ def->getDirectSuperClasses(parent);
+ if (parent.size() != 1) {
+ PrintFatalError(def->getLoc(),
+ "expected to only have one direct superclass");
+ }
+ if (llvm::is_contained(handledClasses, parent.front()))
+ continue;
+
+ Availability avail(def);
+ emitInterfaceDecl(avail, os);
+ handledClasses.push_back(parent.front());
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Hook Registration
+//===----------------------------------------------------------------------===//
+
+// Registers the operation interface generator to mlir-tblgen.
+static mlir::GenRegistration
+ genInterfaceDecls("gen-avail-interface-decls",
+ "Generate availability interface declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitInterfaceDecls(records, os);
+ });
+
+// Registers the operation interface generator to mlir-tblgen.
+static mlir::GenRegistration
+ genInterfaceDefs("gen-avail-interface-defs",
+ "Generate op interface definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitInterfaceDefs(records, os);
+ });
+
+//===----------------------------------------------------------------------===//
// Serialization AutoGen
//===----------------------------------------------------------------------===//
@@ -651,6 +879,17 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
}
//===----------------------------------------------------------------------===//
+// Serialization Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration genSerialization(
+ "gen-spirv-serialization",
+ "Generate SPIR-V (de)serialization utilities and functions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitSerializationFns(records, os);
+ });
+
+//===----------------------------------------------------------------------===//
// Op Utils AutoGen
//===----------------------------------------------------------------------===//
@@ -707,19 +946,92 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
}
//===----------------------------------------------------------------------===//
-// Hook Registration
+// Op Utils Hook Registration
//===----------------------------------------------------------------------===//
-static mlir::GenRegistration genSerialization(
- "gen-spirv-serialization",
- "Generate SPIR-V (de)serialization utilities and functions",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitSerializationFns(records, os);
- });
-
static mlir::GenRegistration
genOpUtils("gen-spirv-op-utils",
"Generate SPIR-V operation utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitOpUtils(records, os);
});
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Availability Impl AutoGen
+//===----------------------------------------------------------------------===//
+
+// Returns the availability spec of the given `def`.
+std::vector<Availability> getAvailabilities(const Record &def) {
+ std::vector<Availability> availabilities;
+ if (auto *availListInit = def.getValueAsListInit("availability")) {
+ availabilities.reserve(availListInit->size());
+ for (auto *availInit : *availListInit)
+ availabilities.emplace_back(
+ llvm::cast<llvm::DefInit>(availInit)->getDef());
+ }
+ return availabilities;
+}
+
+static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
+ mlir::tblgen::FmtContext fctx;
+ fctx.addSubst("overall", "overall");
+
+ std::vector<Availability> opAvailabilities =
+ getAvailabilities(srcOp.getDef());
+
+ // First collect all availablity classes this op should implement.
+ // All availablity instances keep information for the generated interface and
+ // the instance's specific requirement. Here we remember a random instance so
+ // we can get the information regarding the generated interface.
+ llvm::StringMap<Availability> availClasses;
+ for (const Availability &avail : opAvailabilities)
+ availClasses.try_emplace(avail.getClass(), avail);
+
+ // Then generate implementation for each availability class.
+ for (const auto &availClass : availClasses) {
+ StringRef availClassName = availClass.getKey();
+ Availability avail = availClass.getValue();
+
+ // Generate the implementation method signature.
+ os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
+ srcOp.getCppClassName(), avail.getQueryFnName());
+
+ // Create the variable for the final requirement and initialize it.
+ os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(),
+ avail.getMergeInitializer());
+
+ // Update with the op's specific availability spec.
+ for (const Availability &avail : opAvailabilities)
+ if (avail.getClass() == availClassName) {
+ os << " "
+ << tgfmt(avail.getMergeActionCode(),
+ &fctx.addSubst("instance", avail.getMergeInstance()))
+ << ";\n";
+ }
+ os << " return overall;\n";
+ os << "}\n";
+ }
+}
+
+static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
+ for (const auto *def : defs) {
+ Operator op(def);
+ emitAvailabilityImpl(op, os);
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Op Availability Implementation Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+ genOpAvailabilityImpl("gen-spirv-avail-impls",
+ "Generate SPIR-V operation utility definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitAvailabilityImpl(records, os);
+ });
OpenPOWER on IntegriCloud