//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the SPIR-V binary to MLIR SPIR-V module deserialization. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; #define DEBUG_TYPE "spirv-deserialization" /// Decodes a string literal in `words` starting at `wordIndex`. Update the /// latter to point to the position in words after the string literal. static inline StringRef decodeStringLiteral(ArrayRef words, unsigned &wordIndex) { StringRef str(reinterpret_cast(words.data() + wordIndex)); wordIndex += str.size() / 4 + 1; return str; } /// Extracts the opcode from the given first word of a SPIR-V instruction. static inline spirv::Opcode extractOpcode(uint32_t word) { return static_cast(word & 0xffff); } /// Returns true if the given `block` is a function entry block. static inline bool isFnEntryBlock(Block *block) { return block->isEntryBlock() && isa_and_nonnull(block->getParentOp()); } namespace { /// A struct for containing a header block's merge and continue targets. 
/// /// This struct is used to track original structured control flow info from /// SPIR-V blob. This info will be used to create spv.selection/spv.loop /// later. struct BlockMergeInfo { Block *mergeBlock; Block *continueBlock; // nullptr for spv.selection BlockMergeInfo() : mergeBlock(nullptr), continueBlock(nullptr) {} BlockMergeInfo(Block *m, Block *c = nullptr) : mergeBlock(m), continueBlock(c) {} }; /// Map from a selection/loop's header block to its merge (and continue) target. using BlockMergeInfoMap = DenseMap; /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each /// instruction is composed of 32-bit words. The first word of an instruction /// records the total number of words of that instruction using the 16 /// higher-order bits. So this deserializer uses that to get instruction /// boundary and parse instructions and build a SPIR-V ModuleOp gradually. /// // TODO(antiagainst): clean up created ops on errors class Deserializer { public: /// Creates a deserializer for the given SPIR-V `binary` module. /// The SPIR-V ModuleOp will be created into `context. explicit Deserializer(ArrayRef binary, MLIRContext *context); /// Deserializes the remembered SPIR-V binary module. LogicalResult deserialize(); /// Collects the final SPIR-V ModuleOp. Optional collect(); private: //===--------------------------------------------------------------------===// // Module structure //===--------------------------------------------------------------------===// /// Initializes the `module` ModuleOp in this deserializer instance. spirv::ModuleOp createModuleOp(); /// Processes SPIR-V module header in `binary`. LogicalResult processHeader(); /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping /// in the deserializer. LogicalResult processCapability(ArrayRef operands); /// Attaches all collected capabilities to `module` as an attribute. 
void attachCapabilities(); /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping /// in the deserializer. LogicalResult processExtension(ArrayRef words); /// Processes the SPIR-V OpExtInstImport with `operands` and updates /// bookkeeping in the deserializer. LogicalResult processExtInstImport(ArrayRef words); /// Attaches all collected extensions to `module` as an attribute. void attachExtensions(); /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`. LogicalResult processMemoryModel(ArrayRef operands); /// Process SPIR-V OpName with `operands`. LogicalResult processName(ArrayRef operands); /// Processes an OpDecorate instruction. LogicalResult processDecoration(ArrayRef words); // Processes an OpMemberDecorate instruction. LogicalResult processMemberDecoration(ArrayRef words); /// Processes an OpMemberName instruction. LogicalResult processMemberName(ArrayRef words); /// Gets the FuncOp associated with a result of OpFunction. FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. /// This method processes each instruction inside the function and dispatches /// them to their handler method accordingly. LogicalResult processFunction(ArrayRef operands); /// Processes OpFunctionEnd and finalizes function. This wires up block /// argument created from OpPhi instructions and also structurizes control /// flow. LogicalResult processFunctionEnd(ArrayRef operands); /// Gets the constant's attribute and type associated with the given . Optional> getConstant(uint32_t id); /// Gets the constant's integer attribute with the given . Returns a null /// IntegerAttr if the given is not registered or does not correspond to an /// integer constant. IntegerAttr getConstantInt(uint32_t id); /// Returns a symbol to be used for the function name with the given /// result . 
This tries to use the function's OpName if /// exists; otherwise creates one based on the . std::string getFunctionSymbol(uint32_t id); /// Returns a symbol to be used for the specialization constant with the given /// result . This tries to use the specialization constant's OpName if /// exists; otherwise creates one based on the . std::string getSpecConstantSymbol(uint32_t id); /// Gets the specialization constant with the given result . spirv::SpecConstantOp getSpecConstant(uint32_t id) { return specConstMap.lookup(id); } /// Creates a spirv::SpecConstantOp. spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, Attribute defaultValue); /// Processes the OpVariable instructions at current `offset` into `binary`. /// It is expected that this method is used for variables that are to be /// defined at module scope and will be deserialized into a spv.globalVariable /// instruction. LogicalResult processGlobalVariable(ArrayRef operands); /// Gets the global variable associated with a result of OpVariable. spirv::GlobalVariableOp getGlobalVariable(uint32_t id) { return globalVariableMap.lookup(id); } //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// /// Gets type for a given result . Type getType(uint32_t id) { return typeMap.lookup(id); } /// Get the type associated with the result of an OpUndef. Type getUndefType(uint32_t id) { return undefMap.lookup(id); } /// Returns true if the given `type` is for SPIR-V void type. bool isVoidType(Type type) const { return type.isa(); } /// Processes a SPIR-V type instruction with given `opcode` and `operands` and /// registers the type into `module`. 
LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); LogicalResult processArrayType(ArrayRef operands); LogicalResult processFunctionType(ArrayRef operands); LogicalResult processRuntimeArrayType(ArrayRef operands); LogicalResult processStructType(ArrayRef operands); //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// /// Processes a SPIR-V Op{|Spec}Constant instruction with the given /// `operands`. `isSpec` indicates whether this is a specialization constant. LogicalResult processConstant(ArrayRef operands, bool isSpec); /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the /// given `operands`. `isSpec` indicates whether this is a specialization /// constant. LogicalResult processConstantBool(bool isTrue, ArrayRef operands, bool isSpec); /// Processes a SPIR-V OpConstantComposite instruction with the given /// `operands`. LogicalResult processConstantComposite(ArrayRef operands); /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); //===--------------------------------------------------------------------===// // Control flow //===--------------------------------------------------------------------===// /// Returns the block for the given label . Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } // In SPIR-V, structured control flow is explicitly declared using merge // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect, // we use spv.selection and spv.loop to group structured control flow. // The deserializer need to turn structured control flow marked with merge // instructions into using spv.selection/spv.loop ops. // // Because structured control flow can nest and the basic block order have // flexibility, we cannot isolate a structured selection/loop without // deserializing all the blocks. 
So we use the following approach: // // 1. Deserialize all basic blocks in a function and create MLIR blocks for // them into the function's region. In the meanwhile, keep a map between // selection/loop header blocks to their corresponding merge (and continue) // target blocks. // 2. For each selection/loop header block, recursively get all basic blocks // reachable (except the merge block) and put them in a newly created // spv.selection/spv.loop's region. Structured control flow guarantees // that we enter and exit in structured ways and the construct is nestable. // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge // block and redirect all branches to the old header block to the old // merge block (which contains the spv.selection/spv.loop op now). /// For OpPhi instructions, we use block arguments to represent them. OpPhi /// encodes a list of (value, predecessor) pairs. At the time of handling the /// block containing an OpPhi instruction, the predecessor block might not be /// processed yet, also the value sent by it. So we need to defer handling /// the block argument from the predecessors. We use the following approach: /// /// 1. For each OpPhi instruction, add a block argument to the current block /// in construction. Record the block argument in `valueMap` so its uses /// can be resolved. For the list of (value, predecessor) pairs, update /// `blockPhiInfo` for bookkeeping. /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each /// block recorded there to create the proper block arguments on their /// terminators. /// A data structure for containing a SPIR-V block's phi info. It will be /// represented as block argument in SPIR-V dialect. using BlockPhiInfo = SmallVector; // The result of the values sent /// Gets or creates the block corresponding to the given label . The newly /// created block will always be placed at the end of the current function. 
Block *getOrCreateBlock(uint32_t id); LogicalResult processBranch(ArrayRef operands); LogicalResult processBranchConditional(ArrayRef operands); /// Processes a SPIR-V OpLabel instruction with the given `operands`. LogicalResult processLabel(ArrayRef operands); /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`. LogicalResult processSelectionMerge(ArrayRef operands); /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`. LogicalResult processLoopMerge(ArrayRef operands); /// Processes a SPIR-V OpPhi instruction with the given `operands`. LogicalResult processPhi(ArrayRef operands); /// Creates block arguments on predecessors previously recorded when handling /// OpPhi instructions. LogicalResult wireUpBlockArgument(); /// Extracts blocks belonging to a structured selection/loop into a /// spv.selection/spv.loop op. This method iterates until all blocks /// declared as selection/loop headers are handled. LogicalResult structurizeControlFlow(); //===--------------------------------------------------------------------===// // Instruction //===--------------------------------------------------------------------===// /// Get the Value associated with a result . /// /// This method materializes normal constants and inserts "casting" ops /// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA /// value for handling uses of module scope constants/variables in functions. Value getValue(uint32_t id); /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if /// there is no more remaining instructions (`expectedOpcode` will be used to /// compose the error message) or the next instruction is malformed. LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, Optional expectedOpcode = llvm::None); /// Processes a SPIR-V instruction with the given `opcode` and `operands`. 
/// This method is the main entrance for handling SPIR-V instruction; it /// checks the instruction opcode and dispatches to the corresponding handler. /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode) /// might need to be deferred, since they contain forward references to s /// in the deserialized binary, but module in SPIR-V dialect expects these to /// be ssa-uses. LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef operands, bool deferInstructions = true); /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current /// insertion point. LogicalResult processUndef(ArrayRef operands); /// Processes an OpBitcast instruction. LogicalResult processBitcast(ArrayRef words); /// Method to dispatch to the specialized deserialization function for an /// operation in SPIR-V dialect that is a mirror of an instruction in the /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for /// all operations in SPIR-V dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, ArrayRef words); /// Processes a SPIR-V OpExtInst with given `operands`. This slices the /// entries of `operands` that specify the extended instruction set and /// the instruction opcode. The op deserializer is then invoked using the /// other entries. LogicalResult processExtInst(ArrayRef operands); /// Dispatches the deserialization of extended instruction set operation based /// on the extended instruction set name, and instruction opcode. This is /// autogenerated from ODS. LogicalResult dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, uint32_t instructionID, ArrayRef words); /// Method to deserialize an operation in the SPIR-V dialect that is a mirror /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode /// == 1 and autogenSerialization == 1 in ODS. 
template LogicalResult processOp(ArrayRef words) { return emitError(unknownLoc, "unsupported deserialization for ") << OpTy::getOperationName() << " op"; } private: /// The SPIR-V binary module. ArrayRef binary; /// The current word offset into the binary module. unsigned curOffset = 0; /// MLIRContext to create SPIR-V ModuleOp into. MLIRContext *context; // TODO(antiagainst): create Location subclass for binary blob Location unknownLoc; /// The SPIR-V ModuleOp. Optional module; /// The current function under construction. Optional curFunction; /// The current block under construction. Block *curBlock = nullptr; OpBuilder opBuilder; /// The list of capabilities used by the module. llvm::SmallSetVector capabilities; /// The list of extensions used by the module. llvm::SmallSetVector extensions; // Result to type mapping. DenseMap typeMap; // Result to constant attribute and type mapping. /// /// In the SPIR-V binary format, all constants are placed in the module and /// shared by instructions at module level and in subsequent functions. But in /// the SPIR-V dialect, we materialize the constant to where it's used in the /// function. So when seeing a constant instruction in the binary format, we /// don't immediately emit a constant op into the module, we keep its value /// (and type) here. Later when it's used, we materialize the constant. DenseMap> constantMap; // Result to variable mapping. DenseMap specConstMap; // Result to variable mapping. DenseMap globalVariableMap; // Result to function mapping. DenseMap funcMap; // Result to block mapping. DenseMap blockMap; // Header block to its merge (and continue) target mapping. BlockMergeInfoMap blockMergeInfo; // Block to its phi (block argument) mapping. DenseMap blockPhiInfo; // Result to value mapping. DenseMap valueMap; // Mapping from result to undef value of a type. DenseMap undefMap; // Result to name mapping. DenseMap nameMap; // Result to decorations mapping. 
DenseMap decorations; // Result to type decorations. DenseMap typeDecorations; // Result to member decorations. // decorated-struct-type- -> // (struct-member-index -> (decoration -> decoration-operands)) DenseMap>>> memberDecorationMap; // Result to member name. // struct-type- -> (struct-member-index -> name) DenseMap> memberNameMap; // Result to extended instruction set name. DenseMap extendedInstSets; // List of instructions that are processed in a deferred fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function // s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and // spv.ExecutionMode) need these references resolved. So these instructions // are deserialized and stored for processing once the entire binary is // processed. SmallVector>, 4> deferredInstructions; }; } // namespace Deserializer::Deserializer(ArrayRef binary, MLIRContext *context) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), module(createModuleOp()), opBuilder(module->body()) {} LogicalResult Deserializer::deserialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); if (failed(processHeader())) return failure(); spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef operands; auto binarySize = binary.size(); while (curOffset < binarySize) { // Slice the next instruction out and populate `opcode` and `operands`. // Internally this also updates `curOffset`. if (failed(sliceInstruction(opcode, operands))) return failure(); if (failed(processInstruction(opcode, operands))) return failure(); } assert(curOffset == binarySize && "deserializer should never index beyond the binary end"); for (auto &deferred : deferredInstructions) { if (failed(processInstruction(deferred.first, deferred.second, false))) { return failure(); } } // Attaches the capabilities/extensions as an attribute to the module. 
attachCapabilities(); attachExtensions(); LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); return success(); } Optional Deserializer::collect() { return module; } //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// spirv::ModuleOp Deserializer::createModuleOp() { Builder builder(context); OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); // TODO(antiagainst): use target environment to select the version state.addAttribute("major_version", builder.getI32IntegerAttr(1)); state.addAttribute("minor_version", builder.getI32IntegerAttr(0)); spirv::ModuleOp::build(&builder, state); return cast(Operation::create(state)); } LogicalResult Deserializer::processHeader() { if (binary.size() < spirv::kHeaderWordCount) return emitError(unknownLoc, "SPIR-V binary module must have a 5-word header"); if (binary[0] != spirv::kMagicNumber) return emitError(unknownLoc, "incorrect magic number"); // TODO(antiagainst): generator number, bound, schema curOffset = spirv::kHeaderWordCount; return success(); } LogicalResult Deserializer::processCapability(ArrayRef operands) { if (operands.size() != 1) return emitError(unknownLoc, "OpMemoryModel must have one parameter"); auto cap = spirv::symbolizeCapability(operands[0]); if (!cap) return emitError(unknownLoc, "unknown capability: ") << operands[0]; capabilities.insert(*cap); return success(); } void Deserializer::attachCapabilities() { if (capabilities.empty()) return; SmallVector caps; caps.reserve(capabilities.size()); for (auto cap : capabilities) { caps.push_back(spirv::stringifyCapability(cap)); } module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps)); } LogicalResult Deserializer::processExtension(ArrayRef words) { if (words.empty()) { return emitError( unknownLoc, "OpExtension must have a literal string for the extension name"); } unsigned wordIndex = 0; 
StringRef extName = decodeStringLiteral(words, wordIndex); if (wordIndex != words.size()) { return emitError(unknownLoc, "unexpected trailing words in OpExtension instruction"); } extensions.insert(extName); return success(); } LogicalResult Deserializer::processExtInstImport(ArrayRef words) { if (words.size() < 2) { return emitError(unknownLoc, "OpExtInstImport must have a result and a literal " "string for the extended instruction set name"); } unsigned wordIndex = 1; extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); if (wordIndex != words.size()) { return emitError(unknownLoc, "unexpected trailing words in OpExtInstImport"); } return success(); } void Deserializer::attachExtensions() { if (extensions.empty()) return; module->setAttr("extensions", opBuilder.getStrArrayAttr(extensions.getArrayRef())); } LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { if (operands.size() != 2) return emitError(unknownLoc, "OpMemoryModel must have two operands"); module->setAttr( "addressing_model", opBuilder.getI32IntegerAttr(llvm::bit_cast(operands.front()))); module->setAttr( "memory_model", opBuilder.getI32IntegerAttr(llvm::bit_cast(operands.back()))); return success(); } LogicalResult Deserializer::processDecoration(ArrayRef words) { // TODO : This function should also be auto-generated. For now, since only a // few decorations are processed/handled in a meaningful manner, going with a // manual implementation. 
if (words.size() < 2) { return emitError( unknownLoc, "OpDecorate must have at least result and Decoration"); } auto decorationName = stringifyDecoration(static_cast(words[1])); if (decorationName.empty()) { return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; } auto attrName = convertToSnakeCase(decorationName); auto symbol = opBuilder.getIdentifier(attrName); switch (static_cast(words[1])) { case spirv::Decoration::DescriptorSet: case spirv::Decoration::Binding: if (words.size() != 3) { return emitError(unknownLoc, "OpDecorate with ") << decorationName << " needs a single integer literal"; } decorations[words[0]].set( symbol, opBuilder.getI32IntegerAttr(static_cast(words[2]))); break; case spirv::Decoration::BuiltIn: if (words.size() != 3) { return emitError(unknownLoc, "OpDecorate with ") << decorationName << " needs a single integer literal"; } decorations[words[0]].set( symbol, opBuilder.getStringAttr( stringifyBuiltIn(static_cast(words[2])))); break; case spirv::Decoration::ArrayStride: if (words.size() != 3) { return emitError(unknownLoc, "OpDecorate with ") << decorationName << " needs a single integer literal"; } typeDecorations[words[0]] = words[2]; break; case spirv::Decoration::Block: case spirv::Decoration::BufferBlock: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target "; } // Block decoration does not affect spv.struct type, but is still stored for // verification. // TODO: Update StructType to contain this information since // it is needed for many validation rules. 
decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::SpecId: if (words.size() != 3) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single integer literal"; } decorations[words[0]].set( symbol, opBuilder.getI32IntegerAttr(static_cast(words[2]))); break; default: return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; } return success(); } LogicalResult Deserializer::processMemberDecoration(ArrayRef words) { // The binary layout of OpMemberDecorate is different comparing to OpDecorate if (words.size() < 3) { return emitError(unknownLoc, "OpMemberDecorate must have at least 3 operands"); } auto decoration = static_cast(words[2]); if (decoration == spirv::Decoration::Offset && words.size() != 4) { return emitError(unknownLoc, " missing offset specification in OpMemberDecorate with " "Offset decoration"); } ArrayRef decorationOperands; if (words.size() > 3) { decorationOperands = words.slice(3); } memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; return success(); } LogicalResult Deserializer::processMemberName(ArrayRef words) { if (words.size() < 3) { return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); } unsigned wordIndex = 2; auto name = decodeStringLiteral(words, wordIndex); if (wordIndex != words.size()) { return emitError(unknownLoc, "unexpected trailing words in OpMemberName instruction"); } memberNameMap[words[0]][words[1]] = name; return success(); } LogicalResult Deserializer::processFunction(ArrayRef operands) { if (curFunction) { return emitError(unknownLoc, "found function inside function"); } // Get the result type if (operands.size() != 4) { return emitError(unknownLoc, "OpFunction must have 4 parameters"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } if (funcMap.count(operands[1])) { return emitError(unknownLoc, "duplicate 
function definition/declaration"); } auto functionControl = spirv::symbolizeFunctionControl(operands[2]); if (!functionControl) { return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; } if (functionControl.getValue() != spirv::FunctionControl::None) { /// TODO : Handle different function controls return emitError(unknownLoc, "unhandled Function Control: '") << spirv::stringifyFunctionControl(functionControl.getValue()) << "'"; } Type fnType = getType(operands[3]); if (!fnType || !fnType.isa()) { return emitError(unknownLoc, "unknown function type from ") << operands[3]; } auto functionType = fnType.cast(); if ((isVoidType(resultType) && functionType.getNumResults() != 0) || (functionType.getNumResults() == 1 && functionType.getResult(0) != resultType)) { return emitError(unknownLoc, "mismatch in function type ") << functionType << " and return type " << resultType << " specified"; } std::string fnName = getFunctionSymbol(operands[1]); auto funcOp = opBuilder.create(unknownLoc, fnName, functionType, ArrayRef()); curFunction = funcMap[operands[1]] = funcOp; LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " << fnType << ", id = " << operands[1] << ") --\n"); auto *entryBlock = funcOp.addEntryBlock(); LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock << "\n"); // Parse the op argument instructions if (functionType.getNumInputs()) { for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { auto argType = functionType.getInput(i); spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef operands; if (failed(sliceInstruction(opcode, operands, spirv::Opcode::OpFunctionParameter))) { return failure(); } if (opcode != spirv::Opcode::OpFunctionParameter) { return emitError( unknownLoc, "missing OpFunctionParameter instruction for argument ") << i; } if (operands.size() != 2) { return emitError( unknownLoc, "expected result type and result for OpFunctionParameter"); } auto argDefinedType = 
getType(operands[0]); if (!argDefinedType || argDefinedType != argType) { return emitError(unknownLoc, "mismatch in argument type between function type " "definition ") << functionType << " and argument type definition " << argDefinedType << " at argument " << i; } if (getValue(operands[1])) { return emitError(unknownLoc, "duplicate definition of result '") << operands[1]; } auto argValue = funcOp.getArgument(i); valueMap[operands[1]] = argValue; } } // RAII guard to reset the insertion point to the module's region after // deserializing the body of this function. OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef instOperands; // Special handling for the entry block. We need to make sure it starts with // an OpLabel instruction. The entry block takes the same parameters as the // function. All other blocks do not take any parameter. We have already // created the entry block, here we need to register it to the correct label // . if (failed(sliceInstruction(opcode, instOperands, spirv::Opcode::OpFunctionEnd))) { return failure(); } if (opcode == spirv::Opcode::OpFunctionEnd) { LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " << fnType << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } if (opcode != spirv::Opcode::OpLabel) { return emitError(unknownLoc, "a basic block must start with OpLabel"); } if (instOperands.size() != 1) { return emitError(unknownLoc, "OpLabel should only have result "); } blockMap[instOperands[0]] = entryBlock; if (failed(processLabel(instOperands))) { return failure(); } // Then process all the other instructions in the function until we hit // OpFunctionEnd. 
while (succeeded(sliceInstruction(opcode, instOperands, spirv::Opcode::OpFunctionEnd)) && opcode != spirv::Opcode::OpFunctionEnd) { if (failed(processInstruction(opcode, instOperands))) { return failure(); } } if (opcode != spirv::Opcode::OpFunctionEnd) { return failure(); } LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " << fnType << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } LogicalResult Deserializer::processFunctionEnd(ArrayRef operands) { // Process OpFunctionEnd. if (!operands.empty()) { return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); } // Wire up block arguments from OpPhi instructions. // Put all structured control flow in spv.selection/spv.loop ops. if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { return failure(); } curBlock = nullptr; curFunction = llvm::None; return success(); } Optional> Deserializer::getConstant(uint32_t id) { auto constIt = constantMap.find(id); if (constIt == constantMap.end()) return llvm::None; return constIt->getSecond(); } std::string Deserializer::getFunctionSymbol(uint32_t id) { auto funcName = nameMap.lookup(id).str(); if (funcName.empty()) { funcName = "spirv_fn_" + std::to_string(id); } return funcName; } std::string Deserializer::getSpecConstantSymbol(uint32_t id) { auto constName = nameMap.lookup(id).str(); if (constName.empty()) { constName = "spirv_spec_const_" + std::to_string(id); } return constName; } spirv::SpecConstantOp Deserializer::createSpecConstant(Location loc, uint32_t resultID, Attribute defaultValue) { auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); auto op = opBuilder.create(unknownLoc, symName, defaultValue); if (decorations.count(resultID)) { for (auto attr : decorations[resultID].getAttrs()) op.setAttr(attr.first, attr.second); } specConstMap[resultID] = op; return op; } LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { unsigned wordIndex = 
0; if (operands.size() < 3) { return emitError( unknownLoc, "OpVariable needs at least 3 operands, type, and storage class"); } // Result Type. auto type = getType(operands[wordIndex]); if (!type) { return emitError(unknownLoc, "unknown result type : ") << operands[wordIndex]; } auto ptrType = type.dyn_cast(); if (!ptrType) { return emitError(unknownLoc, "expected a result type to be a spv.ptr, found : ") << type; } wordIndex++; // Result . auto variableID = operands[wordIndex]; auto variableName = nameMap.lookup(variableID).str(); if (variableName.empty()) { variableName = "spirv_var_" + std::to_string(variableID); } wordIndex++; // Storage class. auto storageClass = static_cast(operands[wordIndex]); if (ptrType.getStorageClass() != storageClass) { return emitError(unknownLoc, "mismatch in storage class of pointer type ") << type << " and that specified in OpVariable instruction : " << stringifyStorageClass(storageClass); } wordIndex++; // Initializer. FlatSymbolRefAttr initializer = nullptr; if (wordIndex < operands.size()) { auto initializerOp = getGlobalVariable(operands[wordIndex]); if (!initializerOp) { return emitError(unknownLoc, "unknown ") << operands[wordIndex] << "used as initializer"; } wordIndex++; initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); } if (wordIndex != operands.size()) { return emitError(unknownLoc, "found more operands than expected when deserializing " "OpVariable instruction, only ") << wordIndex << " of " << operands.size() << " processed"; } auto varOp = opBuilder.create( unknownLoc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), initializer); // Decorations. 
if (decorations.count(variableID)) { for (auto attr : decorations[variableID].getAttrs()) { varOp.setAttr(attr.first, attr.second); } } globalVariableMap[variableID] = varOp; return success(); } IntegerAttr Deserializer::getConstantInt(uint32_t id) { auto constInfo = getConstant(id); if (!constInfo) { return nullptr; } return constInfo->first.dyn_cast(); } LogicalResult Deserializer::processName(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpName needs at least 2 operands"); } if (!nameMap.lookup(operands[0]).empty()) { return emitError(unknownLoc, "duplicate name found for result ") << operands[0]; } unsigned wordIndex = 1; StringRef name = decodeStringLiteral(operands, wordIndex); if (wordIndex != operands.size()) { return emitError(unknownLoc, "unexpected trailing words in OpName instruction"); } nameMap[operands[0]] = name; return success(); } //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// LogicalResult Deserializer::processType(spirv::Opcode opcode, ArrayRef operands) { if (operands.empty()) { return emitError(unknownLoc, "type instruction with opcode ") << spirv::stringifyOpcode(opcode) << " needs at least one "; } /// TODO: Types might be forward declared in some instructions and need to be /// handled appropriately. 
if (typeMap.count(operands[0])) { return emitError(unknownLoc, "duplicate definition for result ") << operands[0]; } switch (opcode) { case spirv::Opcode::OpTypeVoid: if (operands.size() != 1) { return emitError(unknownLoc, "OpTypeVoid must have no parameters"); } typeMap[operands[0]] = opBuilder.getNoneType(); break; case spirv::Opcode::OpTypeBool: if (operands.size() != 1) { return emitError(unknownLoc, "OpTypeBool must have no parameters"); } typeMap[operands[0]] = opBuilder.getI1Type(); break; case spirv::Opcode::OpTypeInt: if (operands.size() != 3) { return emitError( unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); } // TODO: Ignoring the signedness right now. Need to handle this effectively // in the MLIR representation. typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]); break; case spirv::Opcode::OpTypeFloat: { if (operands.size() != 2) { return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); } Type floatTy; switch (operands[1]) { case 16: floatTy = opBuilder.getF16Type(); break; case 32: floatTy = opBuilder.getF32Type(); break; case 64: floatTy = opBuilder.getF64Type(); break; default: return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") << operands[1]; } typeMap[operands[0]] = floatTy; } break; case spirv::Opcode::OpTypeVector: { if (operands.size() != 3) { return emitError( unknownLoc, "OpTypeVector must have element type and count parameters"); } Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, "OpTypeVector references undefined ") << operands[1]; } typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); } break; case spirv::Opcode::OpTypePointer: { if (operands.size() != 3) { return emitError(unknownLoc, "OpTypePointer must have two parameters"); } auto pointeeType = getType(operands[2]); if (!pointeeType) { return emitError(unknownLoc, "unknown OpTypePointer pointee type ") << operands[2]; } auto storageClass = static_cast(operands[1]); 
typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); case spirv::Opcode::OpTypeRuntimeArray: return processRuntimeArrayType(operands); case spirv::Opcode::OpTypeStruct: return processStructType(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } return success(); } LogicalResult Deserializer::processArrayType(ArrayRef operands) { if (operands.size() != 3) { return emitError(unknownLoc, "OpTypeArray must have element type and count parameters"); } Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, "OpTypeArray references undefined ") << operands[1]; } unsigned count = 0; // TODO(antiagainst): The count can also come frome a specialization constant. auto countInfo = getConstant(operands[2]); if (!countInfo) { return emitError(unknownLoc, "OpTypeArray count ") << operands[2] << "can only come from normal constant right now"; } if (auto intVal = countInfo->first.dyn_cast()) { count = intVal.getInt(); } else { return emitError(unknownLoc, "OpTypeArray count must come from a " "scalar integer constant instruction"); } typeMap[operands[0]] = spirv::ArrayType::get( elementTy, count, typeDecorations.lookup(operands[0])); return success(); } LogicalResult Deserializer::processFunctionType(ArrayRef operands) { assert(!operands.empty() && "No operands for processing function type"); if (operands.size() == 1) { return emitError(unknownLoc, "missing return type for OpTypeFunction"); } auto returnType = getType(operands[1]); if (!returnType) { return emitError(unknownLoc, "unknown return type in OpTypeFunction"); } SmallVector argTypes; for (size_t i = 2, e = operands.size(); i < e; ++i) { auto ty = getType(operands[i]); if (!ty) { return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); } argTypes.push_back(ty); } ArrayRef 
returnTypes; if (!isVoidType(returnType)) { returnTypes = llvm::makeArrayRef(returnType); } typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); return success(); } LogicalResult Deserializer::processRuntimeArrayType(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); } Type memberType = getType(operands[1]); if (!memberType) { return emitError(unknownLoc, "OpTypeRuntimeArray references undefined ") << operands[1]; } typeMap[operands[0]] = spirv::RuntimeArrayType::get(memberType); return success(); } LogicalResult Deserializer::processStructType(ArrayRef operands) { if (operands.empty()) { return emitError(unknownLoc, "OpTypeStruct must have at least result "); } if (operands.size() == 1) { // Handle empty struct. typeMap[operands[0]] = spirv::StructType::getEmpty(context); return success(); } SmallVector memberTypes; for (auto op : llvm::drop_begin(operands, 1)) { Type memberType = getType(op); if (!memberType) { return emitError(unknownLoc, "OpTypeStruct references undefined ") << op; } memberTypes.push_back(memberType); } SmallVector layoutInfo; SmallVector memberDecorationsInfo; if (memberDecorationMap.count(operands[0])) { auto &allMemberDecorations = memberDecorationMap[operands[0]]; for (auto memberIndex : llvm::seq(0, memberTypes.size())) { if (allMemberDecorations.count(memberIndex)) { for (auto &memberDecoration : allMemberDecorations[memberIndex]) { // Check for offset. 
if (memberDecoration.first == spirv::Decoration::Offset) { // If layoutInfo is empty, resize to the number of members; if (layoutInfo.empty()) { layoutInfo.resize(memberTypes.size()); } layoutInfo[memberIndex] = memberDecoration.second[0]; } else { if (!memberDecoration.second.empty()) { return emitError(unknownLoc, "unhandled OpMemberDecoration with decoration ") << stringifyDecoration(memberDecoration.first) << " which has additional operands"; } memberDecorationsInfo.emplace_back(memberIndex, memberDecoration.first); } } } } } typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo); // TODO(ravishankarm): Update StructType to have member name as attribute as // well. return success(); } //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// LogicalResult Deserializer::processConstant(ArrayRef operands, bool isSpec) { StringRef opname = isSpec ? 
"OpSpecConstant" : "OpConstant"; if (operands.size() < 2) { return emitError(unknownLoc) << opname << " must have type and result "; } if (operands.size() < 3) { return emitError(unknownLoc) << opname << " must have at least 1 more parameter"; } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { if (bitwidth == 64) { if (operands.size() == 4) { return success(); } return emitError(unknownLoc) << opname << " should have 2 parameters for 64-bit values"; } if (bitwidth <= 32) { if (operands.size() == 3) { return success(); } return emitError(unknownLoc) << opname << " should have 1 parameter for values with no more than 32 bits"; } return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") << bitwidth; }; auto resultID = operands[1]; if (auto intType = resultType.dyn_cast()) { auto bitwidth = intType.getWidth(); if (failed(checkOperandSizeForBitwidth(bitwidth))) { return failure(); } APInt value; if (bitwidth == 64) { // 64-bit integers are represented with two SPIR-V words. According to // SPIR-V spec: "When the type’s bit width is larger than one word, the // literal’s low-order words appear first." struct DoubleWord { uint32_t word1; uint32_t word2; } words = {operands[2], operands[3]}; value = APInt(64, llvm::bit_cast(words), /*isSigned=*/true); } else if (bitwidth <= 32) { value = APInt(bitwidth, operands[2], /*isSigned=*/true); } auto attr = opBuilder.getIntegerAttr(intType, value); if (isSpec) { createSpecConstant(unknownLoc, resultID, attr); } else { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
constantMap.try_emplace(resultID, attr, intType); } return success(); } if (auto floatType = resultType.dyn_cast()) { auto bitwidth = floatType.getWidth(); if (failed(checkOperandSizeForBitwidth(bitwidth))) { return failure(); } APFloat value(0.f); if (floatType.isF64()) { // Double values are represented with two SPIR-V words. According to // SPIR-V spec: "When the type’s bit width is larger than one word, the // literal’s low-order words appear first." struct DoubleWord { uint32_t word1; uint32_t word2; } words = {operands[2], operands[3]}; value = APFloat(llvm::bit_cast(words)); } else if (floatType.isF32()) { value = APFloat(llvm::bit_cast(operands[2])); } else if (floatType.isF16()) { APInt data(16, operands[2]); value = APFloat(APFloat::IEEEhalf(), data); } auto attr = opBuilder.getFloatAttr(floatType, value); if (isSpec) { createSpecConstant(unknownLoc, resultID, attr); } else { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. constantMap.try_emplace(resultID, attr, floatType); } return success(); } return emitError(unknownLoc, "OpConstant can only generate values of " "scalar integer or floating-point type"); } LogicalResult Deserializer::processConstantBool(bool isTrue, ArrayRef operands, bool isSpec) { if (operands.size() != 2) { return emitError(unknownLoc, "Op") << (isSpec ? "Spec" : "") << "Constant" << (isTrue ? "True" : "False") << " must have type and result "; } auto attr = opBuilder.getBoolAttr(isTrue); auto resultID = operands[1]; if (isSpec) { createSpecConstant(unknownLoc, resultID, attr); } else { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); } return success(); } LogicalResult Deserializer::processConstantComposite(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpConstantComposite must have type and result "); } if (operands.size() < 3) { return emitError(unknownLoc, "OpConstantComposite must have at least 1 parameter"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } SmallVector elements; elements.reserve(operands.size() - 2); for (unsigned i = 2, e = operands.size(); i < e; ++i) { auto elementInfo = getConstant(operands[i]); if (!elementInfo) { return emitError(unknownLoc, "OpConstantComposite component ") << operands[i] << " must come from a normal constant"; } elements.push_back(elementInfo->first); } auto resultID = operands[1]; if (auto vectorType = resultType.dyn_cast()) { auto attr = DenseElementsAttr::get(vectorType, elements); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. constantMap.try_emplace(resultID, attr, resultType); } else if (auto arrayType = resultType.dyn_cast()) { auto attr = opBuilder.getArrayAttr(elements); constantMap.try_emplace(resultID, attr, resultType); } else { return emitError(unknownLoc, "unsupported OpConstantComposite type: ") << resultType; } return success(); } LogicalResult Deserializer::processConstantNull(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpConstantNull must have type and result "); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } auto resultID = operands[1]; if (resultType.isa() || resultType.isa() || resultType.isa()) { auto attr = opBuilder.getZeroAttr(resultType); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
constantMap.try_emplace(resultID, attr, resultType); return success(); } return emitError(unknownLoc, "unsupported OpConstantNull type: ") << resultType; } //===----------------------------------------------------------------------===// // Control flow //===----------------------------------------------------------------------===// Block *Deserializer::getOrCreateBlock(uint32_t id) { if (auto *block = getBlock(id)) { LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id << " @ " << block << "\n"); return block; } // We don't know where this block will be placed finally (in a spv.selection // or spv.loop or function). Create it into the function for now and sort // out the proper place later. auto *block = curFunction->addBlock(); LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ " << block << "\n"); return blockMap[id] = block; } LogicalResult Deserializer::processBranch(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpBranch must appear inside a block"); } if (operands.size() != 1) { return emitError(unknownLoc, "OpBranch must take exactly one target label"); } auto *target = getOrCreateBlock(operands[0]); opBuilder.create(unknownLoc, target); return success(); } LogicalResult Deserializer::processBranchConditional(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpBranchConditional must appear inside a block"); } if (operands.size() != 3 && operands.size() != 5) { return emitError(unknownLoc, "OpBranchConditional must have condition, true label, " "false label, and optionally two branch weights"); } auto condition = getValue(operands[0]); auto *trueBlock = getOrCreateBlock(operands[1]); auto *falseBlock = getOrCreateBlock(operands[2]); Optional> weights; if (operands.size() == 5) { weights = std::make_pair(operands[3], operands[4]); } opBuilder.create( unknownLoc, condition, trueBlock, /*trueArguments=*/ArrayRef(), falseBlock, /*falseArguments=*/ArrayRef(), weights); return 
success(); } LogicalResult Deserializer::processLabel(ArrayRef operands) { if (!curFunction) { return emitError(unknownLoc, "OpLabel must appear inside a function"); } if (operands.size() != 1) { return emitError(unknownLoc, "OpLabel should only have result "); } auto labelID = operands[0]; // We may have forward declared this block. auto *block = getOrCreateBlock(labelID); LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n"); // If we have seen this block, make sure it was just a forward declaration. assert(block->empty() && "re-deserialize the same block!"); opBuilder.setInsertionPointToStart(block); blockMap[labelID] = curBlock = block; return success(); } LogicalResult Deserializer::processSelectionMerge(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); } if (operands.size() < 2) { return emitError( unknownLoc, "OpSelectionMerge must specify merge target and selection control"); } if (static_cast(spirv::SelectionControl::None) != operands[1]) { return emitError(unknownLoc, "unimplmented OpSelectionMerge selection control: ") << operands[2]; } auto *mergeBlock = getOrCreateBlock(operands[0]); if (!blockMergeInfo.try_emplace(curBlock, mergeBlock).second) { return emitError( unknownLoc, "a block cannot have more than one OpSelectionMerge instruction"); } return success(); } LogicalResult Deserializer::processLoopMerge(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpLoopMerge must appear in a block"); } if (operands.size() < 3) { return emitError(unknownLoc, "OpLoopMerge must specify merge target, " "continue target and loop control"); } if (static_cast(spirv::LoopControl::None) != operands[2]) { return emitError(unknownLoc, "unimplmented OpLoopMerge loop control: ") << operands[2]; } auto *mergeBlock = getOrCreateBlock(operands[0]); auto *continueBlock = getOrCreateBlock(operands[1]); if (!blockMergeInfo.try_emplace(curBlock, mergeBlock, continueBlock).second) { 
return emitError( unknownLoc, "a block cannot have more than one OpLoopMerge instruction"); } return success(); } LogicalResult Deserializer::processPhi(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpPhi must appear in a block"); } if (operands.size() < 4) { return emitError(unknownLoc, "OpPhi must specify result type, result , " "and variable-parent pairs"); } // Create a block argument for this OpPhi instruction. Type blockArgType = getType(operands[0]); BlockArgument blockArg = curBlock->addArgument(blockArgType); valueMap[operands[1]] = blockArg; LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg << " id = " << operands[1] << " of type " << blockArgType << '\n'); // For each (value, predecessor) pair, insert the value to the predecessor's // blockPhiInfo entry so later we can fix the block argument there. for (unsigned i = 2, e = operands.size(); i < e; i += 2) { uint32_t value = operands[i]; Block *predecessor = getOrCreateBlock(operands[i + 1]); blockPhiInfo[predecessor].push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor << " with arg id = " << value << '\n'); } return success(); } namespace { /// A class for putting all blocks in a structured selection/loop in a /// spv.selection/spv.loop op. class ControlFlowStructurizer { public: /// Structurizes the loop at the given `headerBlock`. /// /// This method will create an spv.loop op in the `mergeBlock` and move all /// blocks in the structured loop into the spv.loop's region. All branches to /// the `headerBlock` will be redirected to the `mergeBlock`. /// This method will also update `mergeInfo` by remapping all blocks inside to /// the newly cloned ones inside structured control flow op's regions. 
static LogicalResult structurize(Location loc, BlockMergeInfoMap &mergeInfo,
                                   Block *headerBlock, Block *mergeBlock,
                                   Block *continueBlock) {
    return ControlFlowStructurizer(loc, mergeInfo, headerBlock, mergeBlock,
                                   continueBlock)
        .structurizeImpl();
  }

private:
  ControlFlowStructurizer(Location loc, BlockMergeInfoMap &mergeInfo,
                          Block *header, Block *merge, Block *cont)
      : location(loc), blockMergeInfo(mergeInfo), headerBlock(header),
        mergeBlock(merge), continueBlock(cont) {}

  /// Creates a new spv.selection op at the beginning of the `mergeBlock`.
  spirv::SelectionOp createSelectionOp();

  /// Creates a new spv.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp();

  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
  void collectBlocksInConstruct();

  /// Creates the structured op, clones the construct's blocks into its region,
  /// redirects branches into the header to the merge block, and removes the
  /// original blocks (a function entry block is emptied and kept instead).
  LogicalResult structurizeImpl();

  Location location;

  // Map shared with the caller. Entries for nested headers are rewritten in
  // place to point at their cloned blocks.
  BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spv.selection

  // Blocks of this construct in discovery order, starting with headerBlock.
  llvm::SetVector constructBlocks;
};
} // namespace

/// Builds an spv.selection op with SelectionControl::None at the start of
/// `mergeBlock` and adds its merge block.
spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() {
  // Create a builder and set the insertion point to the beginning of the
  // merge block so that the newly created SelectionOp will be inserted there.
  OpBuilder builder(&mergeBlock->front());

  auto control = builder.getI32IntegerAttr(
      static_cast(spirv::SelectionControl::None));
  auto selectionOp = builder.create(location, control);
  selectionOp.addMergeBlock();

  return selectionOp;
}

/// Builds an spv.loop op at the start of `mergeBlock` and adds its entry and
/// merge blocks.
spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
  // Create a builder and set the insertion point to the beginning of the
  // merge block so that the newly created LoopOp will be inserted there.
OpBuilder builder(&mergeBlock->front()); // TODO(antiagainst): handle loop control properly auto loopOp = builder.create(location); loopOp.addEntryAndMergeBlock(); return loopOp; } void ControlFlowStructurizer::collectBlocksInConstruct() { assert(constructBlocks.empty() && "expected empty constructBlocks"); // Put the header block in the work list first. constructBlocks.insert(headerBlock); // For each item in the work list, add its successors excluding the merge // block. for (unsigned i = 0; i < constructBlocks.size(); ++i) { for (auto *successor : constructBlocks[i]->getSuccessors()) if (successor != mergeBlock) constructBlocks.insert(successor); } } LogicalResult ControlFlowStructurizer::structurizeImpl() { Operation *op = nullptr; bool isLoop = continueBlock != nullptr; if (isLoop) { if (auto loopOp = createLoopOp()) op = loopOp.getOperation(); } else { if (auto selectionOp = createSelectionOp()) op = selectionOp.getOperation(); } if (!op) return failure(); Region &body = op->getRegion(0); BlockAndValueMapping mapper; // All references to the old merge block should be directed to the // selection/loop merge block in the SelectionOp/LoopOp's region. mapper.map(mergeBlock, &body.back()); collectBlocksInConstruct(); // We've identified all blocks belonging to the selection/loop's region. Now // need to "move" them into the selection/loop. Instead of really moving the // blocks, in the following we copy them and remap all values and branches. // This is because: // * Inserting a block into a region requires the block not in any region // before. But selections/loops can nest so we can create selection/loop ops // in a nested manner, which means some blocks may already be in a // selection/loop region when to be moved again. 
// * It's much trickier to fix up the branches into and out of the loop's
  //   region: we need to treat not-moved blocks and moved blocks differently:
  //   Not-moved blocks jumping to the loop header block need to jump to the
  //   merge point containing the new loop op but not the loop continue block's
  //   back edge. Moved blocks jumping out of the loop need to jump to the
  //   merge block inside the loop region but not other not-moved blocks.
  //   We cannot use replaceAllUsesWith clearly and it's harder to follow the
  //   logic.

  // Create a corresponding block in the SelectionOp/LoopOp's region for each
  // block in this loop construct.
  OpBuilder builder(body);
  for (auto *block : constructBlocks) {
    // Create a block and insert it before the selection/loop merge block in the
    // SelectionOp/LoopOp's region.
    auto *newBlock = builder.createBlock(&body.back());
    mapper.map(block, newBlock);
    LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
                            << " from block " << block << "\n");
    // A function entry block's arguments are not recreated on the clone. The
    // entry block itself is kept (see the cleanup below), so cloned uses of its
    // arguments keep referring to the originals.
    if (!isFnEntryBlock(block)) {
      for (BlockArgument blockArg : block->getArguments()) {
        auto newArg = newBlock->addArgument(blockArg.getType());
        mapper.map(blockArg, newArg);
        LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
                                << " to " << newArg << '\n');
      }
    } else {
      LLVM_DEBUG(llvm::dbgs()
                 << "[cf] block " << block << " is a function entry block\n");
    }
    // Some operands or successors are not mapped yet when an op is cloned
    // (e.g. blocks cloned later in this loop). The remapping walk below
    // patches them.
    for (auto &op : *block)
      newBlock->push_back(op.clone(mapper));
  }

  // Go through all ops and remap the operands.
  auto remapOperands = [&](Operation *op) {
    for (auto &operand : op->getOpOperands())
      if (auto mappedOp = mapper.lookupOrNull(operand.get()))
        operand.set(mappedOp);
    for (auto &succOp : op->getBlockOperands())
      if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
        succOp.set(mappedOp);
  };
  for (auto &block : body) {
    block.walk(remapOperands);
  }

  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
  // the selection/loop construct into its region. Next we need to fix the
  // connections between this new SelectionOp/LoopOp with existing blocks.

  // All existing incoming branches should go to the merge block, where the
  // SelectionOp/LoopOp resides right now.
  headerBlock->replaceAllUsesWith(mergeBlock);

  if (isLoop) {
    // The loop selection/loop header block may have block arguments. Since now
    // we place the selection/loop op inside the old merge block, we need to
    // make sure the old merge block has the same block argument list.
    assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
    for (BlockArgument blockArg : headerBlock->getArguments()) {
      mergeBlock->addArgument(blockArg.getType());
    }

    // If the loop header block has block arguments, make sure the spv.branch op
    // matches.
    SmallVector blockArgs;
    if (!headerBlock->args_empty())
      blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};

    // The loop entry block should have a unconditional branch jumping to the
    // loop header block.
    builder.setInsertionPointToEnd(&body.front());
    builder.create(location, mapper.lookupOrNull(headerBlock),
                   ArrayRef(blockArgs));
  }

  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
  // cleaned up.
  LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n");
  // First we need to drop all operands' references inside all blocks. This is
  // needed because we can have blocks referencing SSA values from one another.
  for (auto *block : constructBlocks)
    block->dropAllReferences();

  // Then erase all old blocks.
  for (auto *block : constructBlocks) {
    // We've cloned all blocks belonging to this construct into the structured
    // control flow op's region. Among these blocks, some may compose another
    // selection/loop. If so, they will be recorded within blockMergeInfo.
    // We need to update the pointers there to the newly remapped ones so we can
    // continue structurizing them later.

    // TODO(antiagainst): The asserts in the following assume the input SPIR-V
    // blob forms correctly nested selection/loop constructs. We should relax
    // this and support error cases better.
    auto it = blockMergeInfo.find(block);
    if (it != blockMergeInfo.end()) {
      Block *newHeader = mapper.lookupOrNull(block);
      assert(newHeader && "nested loop header block should be remapped!");

      Block *newContinue = it->second.continueBlock;
      if (newContinue) {
        newContinue = mapper.lookupOrNull(newContinue);
        assert(newContinue && "nested loop continue block should be remapped!");
      }

      // A nested merge block that was not cloned keeps its original pointer.
      Block *newMerge = it->second.mergeBlock;
      if (Block *mappedTo = mapper.lookupOrNull(newMerge))
        newMerge = mappedTo;

      // The iterator should be erased before adding a new entry into
      // blockMergeInfo to avoid iterator invalidation.
      blockMergeInfo.erase(it);
      blockMergeInfo.try_emplace(newHeader, newMerge, newContinue);
    }

    // The structured selection/loop's entry block does not have arguments.
    // If the function's header block is also part of the structured control
    // flow, we cannot just simply erase it because it may contain arguments
    // matching the function signature and used by the cloned blocks.
    if (isFnEntryBlock(block)) {
      LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
                              << " to only contain a spv.Branch op\n");
      // Still keep the function entry block for the potential block arguments,
      // but replace all ops inside with a branch to the merge block.
      block->clear();
      builder.setInsertionPointToEnd(block);
      builder.create(location, mergeBlock);
    } else {
      LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n");
      block->erase();
    }
  }

  // Only the pointer value of headerBlock is printed; the block itself may
  // already have been erased or emptied above.
  LLVM_DEBUG(
      llvm::dbgs() << "[cf] after structurizing construct with header block "
                   << headerBlock << ":\n"
                   << *op << '\n');

  return success();
}

/// For each block recorded in blockPhiInfo, resolves the queued OpPhi value
/// ids just before that block's terminator and rebuilds the terminator so it
/// passes them as block arguments.
LogicalResult Deserializer::wireUpBlockArgument() {
  LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");

  OpBuilder::InsertionGuard guard(opBuilder);

  for (const auto &info : blockPhiInfo) {
    Block *block = info.first;
    const BlockPhiInfo &phiInfo = info.second;
    LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
    LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
    LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << '\n');

    // Set insertion point to before this block's terminator early because we
    // may materialize ops via getValue() call.
    auto *op = block->getTerminator();
    opBuilder.setInsertionPoint(op);

    SmallVector blockArgs;
    blockArgs.reserve(phiInfo.size());
    for (uint32_t valueId : phiInfo) {
      if (Value value = getValue(valueId)) {
        blockArgs.push_back(value);
        LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
                                << " id = " << valueId << '\n');
      } else {
        return emitError(unknownLoc, "OpPhi references undefined value!");
      }
    }

    if (auto branchOp = dyn_cast(op)) {
      // Replace the previous branch op with a new one with block arguments.
opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), blockArgs); branchOp.erase(); } else { return emitError(unknownLoc, "unimplemented terminator for Phi creation"); } LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); } blockPhiInfo.clear(); LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); return success(); } LogicalResult Deserializer::structurizeControlFlow() { LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); while (!blockMergeInfo.empty()) { Block *headerBlock = blockMergeInfo.begin()->first; BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); LLVM_DEBUG(headerBlock->print(llvm::dbgs())); auto *mergeBlock = mergeInfo.mergeBlock; assert(mergeBlock && "merge block cannot be nullptr"); if (!mergeBlock->args_empty()) return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); auto *continueBlock = mergeInfo.continueBlock; if (continueBlock) { LLVM_DEBUG(llvm::dbgs() << "[cf] continue block " << continueBlock << ":\n"); LLVM_DEBUG(continueBlock->print(llvm::dbgs())); } // Erase this case before calling into structurizer, who will update // blockMergeInfo. 
blockMergeInfo.erase(blockMergeInfo.begin()); if (failed(ControlFlowStructurizer::structurize(unknownLoc, blockMergeInfo, headerBlock, mergeBlock, continueBlock))) return failure(); } LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); return success(); } //===----------------------------------------------------------------------===// // Instruction //===----------------------------------------------------------------------===// Value Deserializer::getValue(uint32_t id) { if (auto constInfo = getConstant(id)) { // Materialize a `spv.constant` op at every use site. return opBuilder.create(unknownLoc, constInfo->second, constInfo->first); } if (auto varOp = getGlobalVariable(id)) { auto addressOfOp = opBuilder.create( unknownLoc, varOp.type(), opBuilder.getSymbolRefAttr(varOp.getOperation())); return addressOfOp.pointer(); } if (auto constOp = getSpecConstant(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constOp.default_value().getType(), opBuilder.getSymbolRefAttr(constOp.getOperation())); return referenceOfOp.reference(); } if (auto undef = getUndefType(id)) { return opBuilder.create(unknownLoc, undef); } return valueMap.lookup(id); } LogicalResult Deserializer::sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, Optional expectedOpcode) { auto binarySize = binary.size(); if (curOffset >= binarySize) { return emitError(unknownLoc, "expected ") << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) : "more") << " instruction"; } // For each instruction, get its word count from the first word to slice it // from the stream properly, and then dispatch to the instruction handler. 
uint32_t wordCount = binary[curOffset] >> 16; if (wordCount == 0) return emitError(unknownLoc, "word count cannot be zero"); uint32_t nextOffset = curOffset + wordCount; if (nextOffset > binarySize) return emitError(unknownLoc, "insufficient words for the last instruction"); opcode = extractOpcode(binary[curOffset]); operands = binary.slice(curOffset + 1, wordCount - 1); curOffset = nextOffset; return success(); } LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, ArrayRef operands, bool deferInstructions) { LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " << spirv::stringifyOpcode(opcode) << "\n"); // First dispatch all the instructions whose opcode does not correspond to // those that have a direct mirror in the SPIR-V dialect switch (opcode) { case spirv::Opcode::OpBitcast: return processBitcast(operands); case spirv::Opcode::OpCapability: return processCapability(operands); case spirv::Opcode::OpExtension: return processExtension(operands); case spirv::Opcode::OpExtInst: return processExtInst(operands); case spirv::Opcode::OpExtInstImport: return processExtInstImport(operands); case spirv::Opcode::OpMemberName: return processMemberName(operands); case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); case spirv::Opcode::OpEntryPoint: case spirv::Opcode::OpExecutionMode: if (deferInstructions) { deferredInstructions.emplace_back(opcode, operands); return success(); } break; case spirv::Opcode::OpVariable: if (isa(opBuilder.getBlock()->getParentOp())) { return processGlobalVariable(operands); } break; case spirv::Opcode::OpName: return processName(operands); case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpString: case spirv::Opcode::OpSource: case spirv::Opcode::OpSourceContinued: case spirv::Opcode::OpSourceExtension: // TODO: This is debug information embedded in the binary which should be // translated into the spv.module. 
return success(); case spirv::Opcode::OpTypeVoid: case spirv::Opcode::OpTypeBool: case spirv::Opcode::OpTypeInt: case spirv::Opcode::OpTypeFloat: case spirv::Opcode::OpTypeVector: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstant: return processConstant(operands, /*isSpec=*/true); case spirv::Opcode::OpConstantComposite: return processConstantComposite(operands); case spirv::Opcode::OpConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantFalse: return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantFalse: return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); case spirv::Opcode::OpMemberDecorate: return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); case spirv::Opcode::OpLabel: return processLabel(operands); case spirv::Opcode::OpBranch: return processBranch(operands); case spirv::Opcode::OpBranchConditional: return processBranchConditional(operands); case spirv::Opcode::OpSelectionMerge: return processSelectionMerge(operands); case spirv::Opcode::OpLoopMerge: return processLoopMerge(operands); case spirv::Opcode::OpPhi: return processPhi(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: break; } return dispatchToAutogenDeserialization(opcode, operands); } LogicalResult Deserializer::processUndef(ArrayRef operands) { if 
(operands.size() != 2) { return emitError(unknownLoc, "OpUndef instruction must have two operands"); } auto type = getType(operands[0]); if (!type) { return emitError(unknownLoc, "unknown type with OpUndef instruction"); } undefMap[operands[1]] = type; return success(); } // TODO(b/130356985): This method is copied from the auto-generated // deserialization function for OpBitcast instruction. This is to avoid // generating a Bitcast operations for cast from signed integer to unsigned // integer and viceversa. MLIR doesn't have native support for this so they both // end up mapping to the same type right now which is illegal according to // OpBitcast semantics (and enforced by the SPIR-V dialect). LogicalResult Deserializer::processBitcast(ArrayRef words) { SmallVector resultTypes; size_t wordIndex = 0; (void)wordIndex; uint32_t valueID = 0; (void)valueID; { if (wordIndex >= words.size()) { return emitError( unknownLoc, "expected result type while deserializing spirv::BitcastOp"); } auto ty = getType(words[wordIndex]); if (!ty) { return emitError(unknownLoc, "unknown type result : ") << words[wordIndex]; } resultTypes.push_back(ty); wordIndex++; if (wordIndex >= words.size()) { return emitError( unknownLoc, "expected result while deserializing spirv::BitcastOp"); } } valueID = words[wordIndex++]; SmallVector operands; SmallVector attributes; if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; } operands.push_back(arg); wordIndex++; } if (wordIndex != words.size()) { return emitError(unknownLoc, "found more operands than expected when deserializing " "spirv::BitcastOp, only ") << wordIndex << " of " << words.size() << " processed"; } if (resultTypes[0] == operands[0].getType() && resultTypes[0].isa()) { // TODO(b/130356985): This check is added to ignore error in Op verification // due to both signed and unsigned integers mapping to the same // type. 
Without this check this method is same as what is auto-generated. valueMap[valueID] = operands[0]; return success(); } auto op = opBuilder.create(unknownLoc, resultTypes, operands, attributes); (void)op; valueMap[valueID] = op.getResult(); if (decorations.count(valueID)) { auto attrs = decorations[valueID].getAttrs(); attributes.append(attrs.begin(), attrs.end()); } return success(); } LogicalResult Deserializer::processExtInst(ArrayRef operands) { if (operands.size() < 4) { return emitError(unknownLoc, "OpExtInst must have at least 4 operands, result type " ", result , set and instruction opcode"); } if (!extendedInstSets.count(operands[2])) { return emitError(unknownLoc, "undefined set in OpExtInst"); } SmallVector slicedOperands; slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); slicedOperands.append(std::next(operands.begin(), 4), operands.end()); return dispatchToExtensionSetAutogenDeserialization( extendedInstSets[operands[2]], operands[3], slicedOperands); } namespace { template <> LogicalResult Deserializer::processOp(ArrayRef words) { unsigned wordIndex = 0; if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Model specification in OpEntryPoint"); } auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]); if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing in OpEntryPoint"); } // Get the function auto fnID = words[wordIndex++]; // Get the function name auto fnName = decodeStringLiteral(words, wordIndex); // Verify that the function matches the fnName auto parsedFunc = getFunction(fnID); if (!parsedFunc) { return emitError(unknownLoc, "no function matching ") << fnID; } if (parsedFunc.getName() != fnName) { return emitError(unknownLoc, "function name mismatch between OpEntryPoint " "and OpFunction with ") << fnID << ": " << fnName << " vs. 
" << parsedFunc.getName(); } SmallVector interface; while (wordIndex < words.size()) { auto arg = getGlobalVariable(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "undefined result ") << words[wordIndex] << " while decoding OpEntryPoint"; } interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } opBuilder.create(unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), opBuilder.getArrayAttr(interface)); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef words) { unsigned wordIndex = 0; if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing function result in OpExecutionMode"); } // Get the function to get the name of the function auto fnID = words[wordIndex++]; auto fn = getFunction(fnID); if (!fn) { return emitError(unknownLoc, "no function matching ") << fnID; } // Get the Execution mode if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); } auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); // Get the values SmallVector attrListElems; while (wordIndex < words.size()) { attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); } auto values = opBuilder.getArrayAttr(attrListElems); opBuilder.create( unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef operands) { if (operands.size() != 3) { return emitError( unknownLoc, "OpControlBarrier must have execution scope , memory scope " "and memory semantics "); } SmallVector argAttrs; for (auto operand : operands) { auto argAttr = getConstantInt(operand); if (!argAttr) { return emitError(unknownLoc, "expected 32-bit integer constant from ") << operand << " for OpControlBarrier"; } argAttrs.push_back(argAttr); } opBuilder.create(unknownLoc, argAttrs[0], argAttrs[1], argAttrs[2]); return success(); } template <> LogicalResult 
Deserializer::processOp(ArrayRef operands) { if (operands.size() < 3) { return emitError(unknownLoc, "OpFunctionCall must have at least 3 operands"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } auto resultID = operands[1]; auto functionID = operands[2]; auto functionName = getFunctionSymbol(functionID); SmallVector arguments; for (auto operand : llvm::drop_begin(operands, 3)) { auto value = getValue(operand); if (!value) { return emitError(unknownLoc, "unknown ") << operand << " used by OpFunctionCall"; } arguments.push_back(value); } SmallVector resultTypes; if (!isVoidType(resultType)) { resultTypes.push_back(resultType); } auto opFunctionCall = opBuilder.create( unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName), arguments); if (!resultTypes.empty()) { valueMap[resultID] = opFunctionCall.getResult(0); } return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpMemoryBarrier must have memory scope " "and memory semantics "); } SmallVector argAttrs; for (auto operand : operands) { auto argAttr = getConstantInt(operand); if (!argAttr) { return emitError(unknownLoc, "expected 32-bit integer constant from ") << operand << " for OpMemoryBarrier"; } argAttrs.push_back(argAttr); } opBuilder.create(unknownLoc, argAttrs[0], argAttrs[1]); return success(); } // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // namespace Optional spirv::deserialize(ArrayRef binary, MLIRContext *context) { Deserializer deserializer(binary, context); if (failed(deserializer.deserialize())) return llvm::None; return deserializer.collect(); }