summaryrefslogtreecommitdiffstats
path: root/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp270
1 files changed, 258 insertions, 12 deletions
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 639f01458a6..74a1f6c0042 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
@@ -40,6 +41,7 @@ using llvm::StringRef;
using llvm::Twine;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
+using mlir::tblgen::EnumAttrCase;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
@@ -138,6 +140,20 @@ StringRef Availability::getMergeInstance() const {
return def->getValueAsString("instance");
}
+// Returns the availability spec of the given `def`.
+std::vector<Availability> getAvailabilities(const Record &def) {
+ std::vector<Availability> availabilities;
+
+ if (def.getValue("availability")) {
+ std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
+ availabilities.reserve(availDefs.size());
+ for (const Record *avail : availDefs)
+ availabilities.emplace_back(avail);
+ }
+
+ return availabilities;
+}
+
//===----------------------------------------------------------------------===//
// Availability Interface Definitions AutoGen
//===----------------------------------------------------------------------===//
@@ -272,6 +288,186 @@ static mlir::GenRegistration
});
//===----------------------------------------------------------------------===//
+// Enum Availability Query AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
+ raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef enumName = enumAttr.getEnumClassName();
+ std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
+
+ // Mapping from availability class name to (enumerant, availablity
+ // specification) pairs.
+ llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
+ classCaseMap;
+
+ // Place all availablity specifications to their corresponding
+ // availablility classes.
+ for (const EnumAttrCase &enumerant : enumerants)
+ for (const Availability &avail : getAvailabilities(enumerant.getDef()))
+ classCaseMap[avail.getClass()].push_back({enumerant, avail});
+
+ for (const auto &classCasePair : classCaseMap) {
+ Availability avail = classCasePair.getValue().front().second;
+
+ os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n",
+ avail.getMergeInstanceType(), avail.getQueryFnName(),
+ enumName);
+
+ os << " switch (value) {\n";
+ for (const auto &caseSpecPair : classCasePair.getValue()) {
+ EnumAttrCase enumerant = caseSpecPair.first;
+ Availability avail = caseSpecPair.second;
+ os << formatv(" case {0}::{1}: return {2}({3});\n", enumName,
+ enumerant.getSymbol(), avail.getMergeInstanceType(),
+ avail.getMergeInstance());
+ }
+ os << " default: break;\n";
+ os << " }\n"
+ << " return llvm::None;\n"
+ << "}\n";
+ }
+}
+
+static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
+ raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef enumName = enumAttr.getEnumClassName();
+ std::string underlyingType = enumAttr.getUnderlyingType();
+ std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
+
+ // Mapping from availability class name to (enumerant, availablity
+ // specification) pairs.
+ llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
+ classCaseMap;
+
+ // Place all availablity specifications to their corresponding
+ // availablility classes.
+ for (const EnumAttrCase &enumerant : enumerants)
+ for (const Availability &avail : getAvailabilities(enumerant.getDef()))
+ classCaseMap[avail.getClass()].push_back({enumerant, avail});
+
+ for (const auto &classCasePair : classCaseMap) {
+ Availability avail = classCasePair.getValue().front().second;
+
+ os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n",
+ avail.getMergeInstanceType(), avail.getQueryFnName(),
+ enumName);
+
+ os << formatv(
+ " assert(::llvm::countPopulation(static_cast<{0}>(value)) <= 1"
+ " && \"cannot have more than one bit set\");\n",
+ underlyingType);
+
+ os << " switch (value) {\n";
+ for (const auto &caseSpecPair : classCasePair.getValue()) {
+ EnumAttrCase enumerant = caseSpecPair.first;
+ Availability avail = caseSpecPair.second;
+ os << formatv(" case {0}::{1}: return {2}({3});\n", enumName,
+ enumerant.getSymbol(), avail.getMergeInstanceType(),
+ avail.getMergeInstance());
+ }
+ os << " default: break;\n";
+ os << " }\n"
+ << " return llvm::None;\n"
+ << "}\n";
+ }
+}
+
+static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef enumName = enumAttr.getEnumClassName();
+ StringRef cppNamespace = enumAttr.getCppNamespace();
+ auto enumerants = enumAttr.getAllCases();
+
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(cppNamespace, namespaces, "::");
+
+ for (auto ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ llvm::StringSet<> handledClasses;
+
+ // Place all availablity specifications to their corresponding
+ // availablility classes.
+ for (const EnumAttrCase &enumerant : enumerants)
+ for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
+ StringRef className = avail.getClass();
+ if (handledClasses.count(className))
+ continue;
+ os << formatv("llvm::Optional<{0}> {1}({2} value);\n",
+ avail.getMergeInstanceType(), avail.getQueryFnName(),
+ enumName);
+ handledClasses.insert(className);
+ }
+
+ for (auto ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
+}
+
+static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
+ for (const auto *def : defs)
+ emitEnumDecl(*def, os);
+
+ return false;
+}
+
+static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef cppNamespace = enumAttr.getCppNamespace();
+
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(cppNamespace, namespaces, "::");
+
+ for (auto ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ if (enumAttr.isBitEnum()) {
+ emitAvailabilityQueryForBitEnum(enumDef, os);
+ } else {
+ emitAvailabilityQueryForIntEnum(enumDef, os);
+ }
+
+ for (auto ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
+ os << "\n";
+}
+
+static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
+ for (const auto *def : defs)
+ emitEnumDef(*def, os);
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Enum Availability Query Hook Registration
+//===----------------------------------------------------------------------===//
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+ genEnumDecls("gen-spirv-enum-avail-decls",
+ "Generate SPIR-V enum availability declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitEnumDecls(records, os);
+ });
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+ genEnumDefs("gen-spirv-enum-avail-defs",
+ "Generate SPIR-V enum availability definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitEnumDefs(records, os);
+ });
+
+//===----------------------------------------------------------------------===//
// Serialization AutoGen
//===----------------------------------------------------------------------===//
@@ -960,18 +1156,6 @@ static mlir::GenRegistration
// SPIR-V Availability Impl AutoGen
//===----------------------------------------------------------------------===//
-// Returns the availability spec of the given `def`.
-std::vector<Availability> getAvailabilities(const Record &def) {
- std::vector<Availability> availabilities;
- if (auto *availListInit = def.getValueAsListInit("availability")) {
- availabilities.reserve(availListInit->size());
- for (auto *availInit : *availListInit)
- availabilities.emplace_back(
- llvm::cast<llvm::DefInit>(availInit)->getDef());
- }
- return availabilities;
-}
-
static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
mlir::tblgen::FmtContext fctx;
fctx.addSubst("overall", "overall");
@@ -986,6 +1170,16 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
llvm::StringMap<Availability> availClasses;
for (const Availability &avail : opAvailabilities)
availClasses.try_emplace(avail.getClass(), avail);
+ for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
+ const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
+ if (!enumAttr)
+ continue;
+
+ for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
+ for (const Availability &caseAvail :
+ getAvailabilities(enumerant.getDef()))
+ availClasses.try_emplace(caseAvail.getClass(), caseAvail);
+ }
// Then generate implementation for each availability class.
for (const auto &availClass : availClasses) {
@@ -1008,6 +1202,58 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
&fctx.addSubst("instance", avail.getMergeInstance()))
<< ";\n";
}
+
+ // Update with enum attributes' specific availability spec.
+ for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
+ const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
+ if (!enumAttr)
+ continue;
+
+ // (enumerant, availablity specification) pairs for this availability
+ // class.
+ SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
+
+ // Collect all cases' availability specs.
+ for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
+ for (const Availability &caseAvail :
+ getAvailabilities(enumerant.getDef()))
+ if (availClassName == caseAvail.getClass())
+ caseSpecs.push_back({enumerant, caseAvail});
+
+ // If this attribute kind does not have any availablity spec from any of
+ // its cases, no more work to do.
+ if (caseSpecs.empty())
+ continue;
+
+ if (enumAttr->isBitEnum()) {
+ // For BitEnumAttr, we need to iterate over each bit to query its
+ // availability spec.
+ os << formatv(" for (unsigned i = 0; "
+ "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
+ enumAttr->getUnderlyingType());
+ os << formatv(" {0}::{1} attrVal = this->{2}() & "
+ "static_cast<{0}::{1}>(1 << i);\n",
+ enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
+ namedAttr.name);
+ os << formatv(" if (static_cast<{0}>(attrVal) == 0) continue;\n",
+ enumAttr->getUnderlyingType());
+ } else {
+ // For IntEnumAttr, we just need to query the value as a whole.
+ os << " {\n";
+ os << formatv(" auto attrVal = this->{0}();\n", namedAttr.name);
+ }
+ os << formatv(" auto instance = {0}::{1}(attrVal);\n",
+ enumAttr->getCppNamespace(), avail.getQueryFnName());
+ os << " if (instance) "
+ // TODO(antiagainst): use `avail.getMergeCode()` here once ODS supports
+ // dialect-specific contents so that we can use not implementing the
+ // availability interface as indication of no requirements.
+ << tgfmt(caseSpecs.front().second.getMergeActionCode(),
+ &fctx.addSubst("instance", "*instance"))
+ << ";\n";
+ os << " }\n";
+ }
+
os << " return overall;\n";
os << "}\n";
}
OpenPOWER on IntegriCloud