diff options
Diffstat (limited to 'mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp')
-rw-r--r-- | mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 270 |
1 files changed, 258 insertions, 12 deletions
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 639f01458a6..74a1f6c0042 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" @@ -40,6 +41,7 @@ using llvm::StringRef; using llvm::Twine; using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; +using mlir::tblgen::EnumAttrCase; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; @@ -138,6 +140,20 @@ StringRef Availability::getMergeInstance() const { return def->getValueAsString("instance"); } +// Returns the availability spec of the given `def`. +std::vector<Availability> getAvailabilities(const Record &def) { + std::vector<Availability> availabilities; + + if (def.getValue("availability")) { + std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability"); + availabilities.reserve(availDefs.size()); + for (const Record *avail : availDefs) + availabilities.emplace_back(avail); + } + + return availabilities; +} + //===----------------------------------------------------------------------===// // Availability Interface Definitions AutoGen //===----------------------------------------------------------------------===// @@ -272,6 +288,186 @@ static mlir::GenRegistration }); //===----------------------------------------------------------------------===// +// Enum Availability Query AutoGen +//===----------------------------------------------------------------------===// + +static void emitAvailabilityQueryForIntEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases(); + + // Mapping from availability class name to (enumerant, availablity + // specification) pairs. + llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>> + classCaseMap; + + // Place all availablity specifications to their corresponding + // availablility classes. + for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) + classCaseMap[avail.getClass()].push_back({enumerant, avail}); + + for (const auto &classCasePair : classCaseMap) { + Availability avail = classCasePair.getValue().front().second; + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { + EnumAttrCase enumerant = caseSpecPair.first; + Availability avail = caseSpecPair.second; + os << formatv(" case {0}::{1}: return {2}({3});\n", enumName, + enumerant.getSymbol(), avail.getMergeInstanceType(), + avail.getMergeInstance()); + } + os << " default: break;\n"; + os << " }\n" + << " return llvm::None;\n" + << "}\n"; + } +} + +static void emitAvailabilityQueryForBitEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases(); + + // Mapping from availability class name to (enumerant, availablity + // specification) pairs. + llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>> + classCaseMap; + + // Place all availablity specifications to their corresponding + // availablility classes. + for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) + classCaseMap[avail.getClass()].push_back({enumerant, avail}); + + for (const auto &classCasePair : classCaseMap) { + Availability avail = classCasePair.getValue().front().second; + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + + os << formatv( + " assert(::llvm::countPopulation(static_cast<{0}>(value)) <= 1" + " && \"cannot have more than one bit set\");\n", + underlyingType); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { + EnumAttrCase enumerant = caseSpecPair.first; + Availability avail = caseSpecPair.second; + os << formatv(" case {0}::{1}: return {2}({3});\n", enumName, + enumerant.getSymbol(), avail.getMergeInstanceType(), + avail.getMergeInstance()); + } + os << " default: break;\n"; + os << " }\n" + << " return llvm::None;\n" + << "}\n"; + } +} + +static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + auto enumerants = enumAttr.getAllCases(); + + llvm::SmallVector<StringRef, 2> namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + llvm::StringSet<> handledClasses; + + // Place all availablity specifications to their corresponding + // availablility classes. + for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) { + StringRef className = avail.getClass(); + if (handledClasses.count(className)) + continue; + os << formatv("llvm::Optional<{0}> {1}({2} value);\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + handledClasses.insert(className); + } + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + +static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + for (const auto *def : defs) + emitEnumDecl(*def, os); + + return false; +} + +static void emitEnumDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + llvm::SmallVector<StringRef, 2> namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + if (enumAttr.isBitEnum()) { + emitAvailabilityQueryForBitEnum(enumDef, os); + } else { + emitAvailabilityQueryForIntEnum(enumDef, os); + } + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + os << "\n"; +} + +static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + for (const auto *def : defs) + emitEnumDef(*def, os); + + return false; +} + +//===----------------------------------------------------------------------===// +// Enum Availability Query Hook Registration +//===----------------------------------------------------------------------===// + +// Registers the enum utility generator to mlir-tblgen. +static mlir::GenRegistration + genEnumDecls("gen-spirv-enum-avail-decls", + "Generate SPIR-V enum availability declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumDecls(records, os); + }); + +// Registers the enum utility generator to mlir-tblgen. +static mlir::GenRegistration + genEnumDefs("gen-spirv-enum-avail-defs", + "Generate SPIR-V enum availability definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumDefs(records, os); + }); + +//===----------------------------------------------------------------------===// // Serialization AutoGen //===----------------------------------------------------------------------===// @@ -960,18 +1156,6 @@ static mlir::GenRegistration // SPIR-V Availability Impl AutoGen //===----------------------------------------------------------------------===// -// Returns the availability spec of the given `def`. -std::vector<Availability> getAvailabilities(const Record &def) { - std::vector<Availability> availabilities; - if (auto *availListInit = def.getValueAsListInit("availability")) { - availabilities.reserve(availListInit->size()); - for (auto *availInit : *availListInit) - availabilities.emplace_back( - llvm::cast<llvm::DefInit>(availInit)->getDef()); - } - return availabilities; -} - static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { mlir::tblgen::FmtContext fctx; fctx.addSubst("overall", "overall"); @@ -986,6 +1170,16 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { llvm::StringMap<Availability> availClasses; for (const Availability &avail : opAvailabilities) availClasses.try_emplace(avail.getClass(), avail); + for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { + const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr); + if (!enumAttr) + continue; + + for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const Availability &caseAvail : + getAvailabilities(enumerant.getDef())) + availClasses.try_emplace(caseAvail.getClass(), caseAvail); + } // Then generate implementation for each availability class. for (const auto &availClass : availClasses) { @@ -1008,6 +1202,58 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { &fctx.addSubst("instance", avail.getMergeInstance())) << ";\n"; } + + // Update with enum attributes' specific availability spec. + for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { + const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr); + if (!enumAttr) + continue; + + // (enumerant, availablity specification) pairs for this availability + // class. + SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs; + + // Collect all cases' availability specs. + for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const Availability &caseAvail : + getAvailabilities(enumerant.getDef())) + if (availClassName == caseAvail.getClass()) + caseSpecs.push_back({enumerant, caseAvail}); + + // If this attribute kind does not have any availablity spec from any of + // its cases, no more work to do. + if (caseSpecs.empty()) + continue; + + if (enumAttr->isBitEnum()) { + // For BitEnumAttr, we need to iterate over each bit to query its + // availability spec. + os << formatv(" for (unsigned i = 0; " + "i < std::numeric_limits<{0}>::digits; ++i) {{\n", + enumAttr->getUnderlyingType()); + os << formatv(" {0}::{1} attrVal = this->{2}() & " + "static_cast<{0}::{1}>(1 << i);\n", + enumAttr->getCppNamespace(), enumAttr->getEnumClassName(), + namedAttr.name); + os << formatv(" if (static_cast<{0}>(attrVal) == 0) continue;\n", + enumAttr->getUnderlyingType()); + } else { + // For IntEnumAttr, we just need to query the value as a whole. + os << " {\n"; + os << formatv(" auto attrVal = this->{0}();\n", namedAttr.name); + } + os << formatv(" auto instance = {0}::{1}(attrVal);\n", + enumAttr->getCppNamespace(), avail.getQueryFnName()); + os << " if (instance) " + // TODO(antiagainst): use `avail.getMergeCode()` here once ODS supports + // dialect-specific contents so that we can use not implementing the + // availability interface as indication of no requirements. + << tgfmt(caseSpecs.front().second.getMergeActionCode(), + &fctx.addSubst("instance", "*instance")) + << ";\n"; + os << " }\n"; + } + os << " return overall;\n"; os << "}\n"; } |