diff options
Diffstat (limited to 'mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp')
-rw-r--r-- | mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 725 |
1 files changed, 725 insertions, 0 deletions
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp new file mode 100644 index 00000000000..d65b216e109 --- /dev/null +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -0,0 +1,725 @@ +//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// SPIRVSerializationGen generates common utility functions for SPIR-V +// serialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/StringExtras.h" +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using llvm::ArrayRef; +using llvm::formatv; +using llvm::raw_ostream; +using llvm::raw_string_ostream; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::SmallVector; +using llvm::SMLoc; +using llvm::StringMap; +using llvm::StringRef; +using llvm::Twine; +using mlir::tblgen::Attribute; +using mlir::tblgen::EnumAttr; +using mlir::tblgen::NamedAttribute; +using mlir::tblgen::NamedTypeConstraint; +using mlir::tblgen::Operator; + +//===----------------------------------------------------------------------===// +// Serialization AutoGen +//===----------------------------------------------------------------------===// + +// Writes the following function to `os`: +// inline uint32_t getOpcode(<op-class-name>) { return <opcode>; } +static void emitGetOpcodeFunction(const Record *record, Operator const &op, + raw_ostream &os) { + os << formatv("template <> constexpr inline ::mlir::spirv::Opcode " + "getOpcode<{0}>() {{\n", + op.getQualCppClassName()); + os << formatv(" return ::mlir::spirv::Opcode::{0};\n", + record->getValueAsString("spirvOpName")); + os << "}\n"; +} + +/// Forward declaration of function to return the SPIR-V opcode corresponding to +/// an operation. This function will be generated for all SPV_Op instances that +/// have hasOpcode = 1. +static void declareOpcodeFn(raw_ostream &os) { + os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode " + "getOpcode();\n"; +} + +/// Generates code to serialize attributes of a SPV_Op `op` into `os`. The +/// generates code extracts the attribute with name `attrName` from +/// `operandList` of `op`. +static void emitAttributeSerialization(const Attribute &attr, + ArrayRef<SMLoc> loc, StringRef tabs, + StringRef opVar, StringRef operandList, + StringRef attrName, raw_ostream &os) { + os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName); + os << tabs << "if (attr) {\n"; + if (attr.getAttrDefName() == "SPV_ScopeAttr" || + attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + os << tabs + << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), " + "attr.cast<IntegerAttr>()));\n", + operandList, opVar); + } else if (attr.getAttrDefName() == "I32ArrayAttr") { + // Serialize all the elements of the array + os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n"; + os << tabs + << formatv(" {0}.push_back(static_cast<uint32_t>(" + "attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n", + operandList); + os << tabs << " }\n"; + } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { + os << tabs + << formatv(" {0}.push_back(static_cast<uint32_t>(" + "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n", + operandList); + } else { + PrintFatalError( + loc, + llvm::Twine( + "unhandled attribute type in SPIR-V serialization generation : '") + + attr.getAttrDefName() + llvm::Twine("'")); + } + os << tabs << "}\n"; +} + +/// Generates code to serialize the operands of a SPV_Op `op` into `os`. The +/// generated queries the SSA-ID if operand is a SSA-Value, or serializes the +/// attributes. The `operands` vector is updated appropriately. `elidedAttrs` +/// updated as well to include the serialized attributes. +static void emitOperandSerialization(const Operator &op, ArrayRef<SMLoc> loc, + StringRef tabs, StringRef opVar, + StringRef operands, StringRef elidedAttrs, + raw_ostream &os) { + auto operandNum = 0; + for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); + os << tabs << "{\n"; + if (argument.is<NamedTypeConstraint *>()) { + os << tabs + << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar, + operandNum); + os << tabs << " auto argID = getValueID(arg);\n"; + os << tabs << " if (!argID) {\n"; + os << tabs + << formatv(" return emitError({0}.getLoc(), " + "\"operand #{1} has a use before def\");\n", + opVar, operandNum); + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(argID);\n", operands); + os << " }\n"; + operandNum++; + } else { + auto attr = argument.get<NamedAttribute *>(); + auto newtabs = tabs.str() + " "; + emitAttributeSerialization( + (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), + loc, newtabs, opVar, operands, attr->name, os); + os << newtabs + << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name); + } + os << tabs << "}\n"; + } +} + +/// Generates code to serializes the result of SPV_Op `op` into `os`. The +/// generated gets the ID for the type of the result (if any), the SSA-ID of +/// the result and updates `resultID` with the SSA-ID. +static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc, + StringRef tabs, StringRef opVar, + StringRef operands, StringRef resultID, + raw_ostream &os) { + if (op.getNumResults() == 1) { + StringRef resultTypeID("resultTypeID"); + os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID); + os << tabs + << formatv( + "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n", + opVar, resultTypeID); + os << tabs << " return failure();\n"; + os << tabs << "}\n"; + os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID); + // Create an SSA result <id> for the op + os << tabs << formatv("{0} = getNextID();\n", resultID); + os << tabs + << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID); + os << tabs << formatv("{0}.push_back({1});\n", operands, resultID); + } else if (op.getNumResults() != 0) { + PrintFatalError(loc, "SPIR-V ops can only have zero or one result"); + } +} + +/// Generates code to serialize attributes of SPV_Op `op` that become +/// decorations on the `resultID` of the serialized operation `opVar` in the +/// SPIR-V binary. +static void emitDecorationSerialization(const Operator &op, StringRef tabs, + StringRef opVar, StringRef elidedAttrs, + StringRef resultID, raw_ostream &os) { + if (op.getNumResults() == 1) { + // All non-argument attributes translated into OpDecorate instruction + os << tabs << formatv("for (auto attr : {0}.getAttrs()) {{\n", opVar); + os << tabs + << formatv(" if (llvm::any_of({0}, [&](StringRef elided)", elidedAttrs); + os << " {return attr.first.is(elided);})) {\n"; + os << tabs << " continue;\n"; + os << tabs << " }\n"; + os << tabs + << formatv( + " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n", + opVar, resultID); + os << tabs << " return failure();\n"; + os << tabs << " }\n"; + os << tabs << "}\n"; + } +} + +/// Generates code to serialize an SPV_Op `op` into `os`. +static void emitSerializationFunction(const Record *attrClass, + const Record *record, const Operator &op, + raw_ostream &os) { + // If the record has 'autogenSerialization' set to 0, nothing to do + if (!record->getValueAsBit("autogenSerialization")) { + return; + } + StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"), + resultID("resultID"); + os << formatv( + "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n", + op.getQualCppClassName(), opVar); + os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands); + os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs); + + // Serialize result information. + if (op.getNumResults() == 1) { + os << formatv(" uint32_t {0} = 0;\n", resultID); + emitResultSerialization(op, record->getLoc(), " ", opVar, operands, + resultID, os); + } + + // Process arguments. + emitOperandSerialization(op, record->getLoc(), " ", opVar, operands, + elidedAttrs, os); + + if (record->isSubClassOf("SPV_ExtInstOp")) { + os << formatv(" encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", + opVar, record->getValueAsString("extendedInstSetName"), + record->getValueAsInt("extendedInstOpcode"), operands); + } else { + os << formatv(" encodeInstructionInto(" + "functionBody, spirv::getOpcode<{0}>(), {1});\n", + op.getQualCppClassName(), operands); + } + + // Process decorations. + emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os); + + os << " return success();\n"; + os << "}\n\n"; +} + +/// Generates the prologue for the function that dispatches the serialization of +/// the operation `opVar` based on its opcode. +static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) { + os << formatv( + "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " + "*{0}) {{\n ", + opVar); +} + +/// Generates the body of the dispatch function. This function generates the +/// check that if satisfied, will call the serialization function generated for +/// the `op`. +static void emitSerializationDispatch(const Operator &op, StringRef tabs, + StringRef opVar, raw_ostream &os) { + os << tabs + << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar); + os << tabs + << formatv(" return processOp(cast<{0}>({1}));\n", + op.getQualCppClassName(), opVar); + os << tabs << "} else"; +} + +/// Generates the epilogue for the function that dispatches the serialization of +/// the operation. +static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) { + os << " {\n"; + os << formatv( + " return {0}->emitError(\"unhandled operation serialization\");\n", + opVar); + os << " }\n"; + os << " return success();\n"; + os << "}\n\n"; +} + +/// Generates code to deserialize the attribute of a SPV_Op into `os`. The +/// generated code reads the `words` of the serialized instruction at +/// position `wordIndex` and adds the deserialized attribute into `attrList`. +static void emitAttributeDeserialization(const Attribute &attr, + ArrayRef<SMLoc> loc, StringRef tabs, + StringRef attrList, StringRef attrName, + StringRef words, StringRef wordIndex, + raw_ostream &os) { + if (attr.getAttrDefName() == "SPV_ScopeAttr" || + attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "getConstantInt({2}[{3}++])));\n", + attrList, attrName, words, wordIndex); + } else if (attr.getAttrDefName() == "I32ArrayAttr") { + os << tabs << "SmallVector<Attribute, 4> attrListElems;\n"; + os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " " + "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))" + ";\n", + words, wordIndex); + os << tabs << "}\n"; + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "opBuilder.getArrayAttr(attrListElems)));\n", + attrList, attrName); + } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "opBuilder.getI32IntegerAttr({2}[{3}++])));\n", + attrList, attrName, words, wordIndex); + } else { + PrintFatalError( + loc, llvm::Twine( + "unhandled attribute type in deserialization generation : '") + + attr.getAttrDefName() + llvm::Twine("'")); + } +} + +/// Generates the code to deserialize the result of an SPV_Op `op` into +/// `os`. The generated code gets the type of the result specified at +/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1 +/// and updates the `resultType` and `valueID` with the parsed type and SSA ID, +/// respectively. +static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc, + StringRef tabs, StringRef words, + StringRef wordIndex, + StringRef resultTypes, StringRef valueID, + raw_ostream &os) { + // Deserialize result information if it exists + if (op.getNumResults() == 1) { + os << tabs << "{\n"; + os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " return emitError(unknownLoc, \"expected result type <id> " + "while deserializing {0}\");\n", + op.getQualCppClassName()); + os << tabs << " }\n"; + os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex); + os << tabs << " if (!ty) {\n"; + os << tabs + << formatv( + " return emitError(unknownLoc, \"unknown type result <id> : " + "\") << {0}[{1}];\n", + words, wordIndex); + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes); + os << tabs << formatv(" {0}++;\n", wordIndex); + os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " return emitError(unknownLoc, \"expected result <id> while " + "deserializing {0}\");\n", + op.getQualCppClassName()); + os << tabs << " }\n"; + os << tabs << "}\n"; + os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex); + } else if (op.getNumResults() != 0) { + PrintFatalError(loc, "SPIR-V ops can have only zero or one result"); + } +} + +/// Generates the code to deserialize the operands of an SPV_Op `op` into +/// `os`. The generated code reads the `words` of the binary instruction, from +/// position `wordIndex` to the end, and either gets the Value corresponding to +/// the ID encoded, or deserializes the attributes encoded. The parsed +/// operand(attribute) is added to the `operands` list or `attributes` list. +static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc, + StringRef tabs, StringRef words, + StringRef wordIndex, StringRef operands, + StringRef attributes, raw_ostream &os) { + // Process operands/attributes + unsigned operandNum = 0; + for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); + if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) { + if (valueArg->isVariadic()) { + if (i != e - 1) { + PrintFatalError(loc, + "SPIR-V ops can have Variadic<..> argument only if " + "it's the last argument"); + } + os << tabs + << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words); + } else { + os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words); + } + os << " {\n"; + os << tabs + << formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex); + os << tabs << " if (!arg) {\n"; + os << tabs + << formatv( + " return emitError(unknownLoc, \"unknown result <id> : \") " + "<< {0}[{1}];\n", + words, wordIndex); + os << tabs << " }\n"; + os << tabs << formatv(" {0}.push_back(arg);\n", operands); + if (!valueArg->isVariadic()) { + os << tabs << formatv(" {0}++;\n", wordIndex); + } + operandNum++; + os << tabs << "}\n"; + } else { + os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words); + auto attr = argument.get<NamedAttribute *>(); + auto newtabs = tabs.str() + " "; + emitAttributeDeserialization( + (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), + loc, newtabs, attributes, attr->name, words, wordIndex, os); + os << " }\n"; + } + } + + os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words); + os << tabs + << formatv( + " return emitError(unknownLoc, \"found more operands than " + "expected when deserializing {0}, only \") << {1} << \" of \" << " + "{2}.size() << \" processed\";\n", + op.getQualCppClassName(), wordIndex, words); + os << tabs << "}\n\n"; +} + +/// Generates code to update the `attributes` vector with the attributes +/// obtained from parsing the decorations in the SPIR-V binary associated with +/// an <id> `valueID` +static void emitDecorationDeserialization(const Operator &op, StringRef tabs, + StringRef valueID, + StringRef attributes, + raw_ostream &os) { + // Import decorations parsed + if (op.getNumResults() == 1) { + os << tabs << formatv("if (decorations.count({0})) {{\n", valueID); + os << tabs + << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID); + os << tabs + << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes); + os << tabs << "}\n"; + } +} + +/// Generates code to deserialize an SPV_Op `op` into `os`. +static void emitDeserializationFunction(const Record *attrClass, + const Record *record, + const Operator &op, raw_ostream &os) { + // If the record has 'autogenSerialization' set to 0, nothing to do + if (!record->getValueAsBit("autogenSerialization")) { + return; + } + StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"), + wordIndex("wordIndex"), opVar("op"), operands("operands"), + attributes("attributes"); + os << formatv("template <> " + "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" + "uint32_t> {1}) {{\n", + op.getQualCppClassName(), words); + os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes); + os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex); + os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID); + + // Deserialize result information + emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex, + resultTypes, valueID, os); + + os << formatv(" SmallVector<Value, 4> {0};\n", operands); + os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes); + // Operand deserialization + emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, + operands, attributes, os); + + os << formatv( + " auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); " + "(void){1};\n", + op.getQualCppClassName(), opVar, resultTypes, operands, attributes); + if (op.getNumResults() == 1) { + os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar); + } + + // Decorations + emitDecorationDeserialization(op, " ", valueID, attributes, os); + os << " return success();\n"; + os << "}\n\n"; +} + +/// Generates the prologue for the function that dispatches the deserialization +/// based on the `opcode`. +static void initDispatchDeserializationFn(StringRef opcode, StringRef words, + raw_ostream &os) { + os << formatv( + "LogicalResult " + "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode {0}, " + "ArrayRef<uint32_t> {1}) {{\n", + opcode, words); + os << formatv(" switch ({0}) {{\n", opcode); +} + +/// Generates the body of the dispatch function, by generating the case label +/// for an opcode and the call to the method to perform the deserialization. +static void emitDeserializationDispatch(const Operator &op, const Record *def, + StringRef tabs, StringRef words, + raw_ostream &os) { + os << tabs + << formatv("case spirv::Opcode::{0}:\n", + def->getValueAsString("spirvOpName")); + os << tabs + << formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(), + words); +} + +/// Generates the epilogue for the function that dispatches the deserialization +/// of the operation. +static void finalizeDispatchDeserializationFn(StringRef opcode, + raw_ostream &os) { + os << " default:\n"; + os << " ;\n"; + os << " }\n"; + StringRef opcodeVar("opcodeString"); + os << formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar, + opcode); + os << formatv(" if (!{0}.empty()) {{\n", opcodeVar); + os << formatv(" return emitError(unknownLoc, \"unhandled deserialization " + "of \") << {0};\n", + opcodeVar); + os << " } else {\n"; + os << formatv(" return emitError(unknownLoc, \"unhandled opcode \") << " + "static_cast<uint32_t>({0});\n", + opcode); + os << " }\n"; + os << "}\n"; +} + +static void initExtendedSetDeserializationDispatch(StringRef extensionSetName, + StringRef instructionID, + StringRef words, + raw_ostream &os) { + os << formatv("LogicalResult " + "Deserializer::dispatchToExtensionSetAutogenDeserialization(" + "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n", + extensionSetName, instructionID, words); +} + +static void +emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper, + raw_ostream &os) { + StringRef extensionSetName("extensionSetName"), + instructionID("instructionID"), words("words"); + + // First iterate over all ops derived from SPV_ExtensionSetOps to get all + // extensionSets. + + // For each of the extensions a separate raw_string_ostream is used to + // generate code into. These are then concatenated at the end. Since + // raw_string_ostream needs a string&, use a vector to store all the string + // that are captured by reference within raw_string_ostream. + StringMap<raw_string_ostream> extensionSets; + SmallVector<std::string, 1> extensionSetNames; + + initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words, + os); + auto defs = recordKeeper.getAllDerivedDefinitions("SPV_ExtInstOp"); + for (const auto *def : defs) { + if (!def->getValueAsBit("autogenSerialization")) { + continue; + } + Operator op(def); + auto setName = def->getValueAsString("extendedInstSetName"); + if (!extensionSets.count(setName)) { + extensionSetNames.push_back(""); + extensionSets.try_emplace(setName, extensionSetNames.back()); + auto &setos = extensionSets.find(setName)->second; + setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName); + setos << formatv(" switch ({0}) {{\n", instructionID); + } + auto &setos = extensionSets.find(setName)->second; + setos << formatv(" case {0}:\n", + def->getValueAsInt("extendedInstOpcode")); + setos << formatv(" return processOp<{0}>({1});\n", + op.getQualCppClassName(), words); + } + + // Append the dispatch code for all the extended sets. + for (auto &extensionSet : extensionSets) { + os << extensionSet.second.str(); + os << " default:\n"; + os << formatv( + " return emitError(unknownLoc, \"unhandled deserializations of " + "\") << {0} << \" from extension set \" << {1};\n", + instructionID, extensionSetName); + os << " }\n"; + os << " }\n"; + } + + os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of " + "extended instruction set {0}\");\n", + extensionSetName); + os << "}\n"; +} + +/// Emits all the autogenerated serialization/deserializations functions for the +/// SPV_Ops. +static bool emitSerializationFns(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os); + + std::string dSerFnString, dDesFnString, serFnString, deserFnString, + utilsString; + raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), + serFn(serFnString), deserFn(deserFnString), utils(utilsString); + auto attrClass = recordKeeper.getClass("Attr"); + + // Emit the serialization and deserialization functions simultaneously. + declareOpcodeFn(utils); + StringRef opVar("op"); + StringRef opcode("opcode"), words("words"); + + // Handle the SPIR-V ops. + initDispatchSerializationFn(opVar, dSerFn); + initDispatchDeserializationFn(opcode, words, dDesFn); + auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); + for (const auto *def : defs) { + Operator op(def); + emitSerializationFunction(attrClass, def, op, serFn); + emitDeserializationFunction(attrClass, def, op, deserFn); + if (def->getValueAsBit("hasOpcode") || def->isSubClassOf("SPV_ExtInstOp")) { + emitSerializationDispatch(op, " ", opVar, dSerFn); + } + if (def->getValueAsBit("hasOpcode")) { + emitGetOpcodeFunction(def, op, utils); + emitDeserializationDispatch(op, def, " ", words, dDesFn); + } + } + finalizeDispatchSerializationFn(opVar, dSerFn); + finalizeDispatchDeserializationFn(opcode, dDesFn); + + emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn); + + os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n"; + os << utils.str(); + os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n"; + + os << "#ifdef GET_SERIALIZATION_FNS\n\n"; + os << serFn.str(); + os << dSerFn.str(); + os << "#endif // GET_SERIALIZATION_FNS\n\n"; + + os << "#ifdef GET_DESERIALIZATION_FNS\n\n"; + os << deserFn.str(); + os << dDesFn.str(); + os << "#endif // GET_DESERIALIZATION_FNS\n\n"; + + return false; +} + +//===----------------------------------------------------------------------===// +// Op Utils AutoGen +//===----------------------------------------------------------------------===// + +static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { + os << formatv("template <typename EnumClass> inline constexpr StringRef " + "attributeName();\n"); +} + +static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) { + os << "template <typename EnumClass> using SymbolizeFnTy = " + "llvm::Optional<EnumClass> (*)(StringRef);\n"; + os << "template <typename EnumClass> inline constexpr " + "SymbolizeFnTy<EnumClass> symbolizeEnum();\n"; +} + +static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, + raw_ostream &os) { + auto enumName = enumAttr.getEnumClassName(); + os << formatv("template <> inline StringRef attributeName<{0}>() {{\n", + enumName); + os << " " + << formatv("static constexpr const char attrName[] = \"{0}\";\n", + mlir::convertToSnakeCase(enumName)); + os << " return attrName;\n"; + os << "}\n"; +} + +static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, + raw_ostream &os) { + auto enumName = enumAttr.getEnumClassName(); + auto strToSymFnName = enumAttr.getStringToSymbolFnName(); + os << formatv( + "template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>() {{\n", + enumName); + os << " return " << strToSymFnName << ";\n"; + os << "}\n"; +} + +static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Op Utilities", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + os << "#ifndef SPIRV_OP_UTILS_H_\n"; + os << "#define SPIRV_OP_UTILS_H_\n"; + emitEnumGetAttrNameFnDecl(os); + emitEnumGetSymbolizeFnDecl(os); + for (const auto *def : defs) { + EnumAttr enumAttr(*def); + emitEnumGetAttrNameFnDefn(enumAttr, os); + emitEnumGetSymbolizeFnDefn(enumAttr, os); + } + os << "#endif // SPIRV_OP_UTILS_H\n"; + return false; +} + +//===----------------------------------------------------------------------===// +// Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration genSerialization( + "gen-spirv-serialization", + "Generate SPIR-V (de)serialization utilities and functions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitSerializationFns(records, os); + }); + +static mlir::GenRegistration + genOpUtils("gen-spirv-op-utils", + "Generate SPIR-V operation utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitOpUtils(records, os); + }); |