summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp2423
1 files changed, 2423 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
new file mode 100644
index 00000000000..17ddc48573a
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -0,0 +1,2423 @@
+//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPIR-V binary to MLIR SPIR-V module deserialization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Serialization.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "spirv-deserialization"
+
+/// Decodes a string literal in `words` starting at `wordIndex`. Update the
+/// latter to point to the position in words after the string literal.
+static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
+ unsigned &wordIndex) {
+ StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
+ wordIndex += str.size() / 4 + 1;
+ return str;
+}
+
+/// Extracts the opcode from the given first word of a SPIR-V instruction.
+static inline spirv::Opcode extractOpcode(uint32_t word) {
+ return static_cast<spirv::Opcode>(word & 0xffff);
+}
+
+/// Returns true if the given `block` is a function entry block.
+static inline bool isFnEntryBlock(Block *block) {
+ return block->isEntryBlock() && isa_and_nonnull<FuncOp>(block->getParentOp());
+}
+
+namespace {
+/// A struct for containing a header block's merge and continue targets.
+///
+/// This struct is used to track original structured control flow info from
+/// SPIR-V blob. This info will be used to create spv.selection/spv.loop
+/// later.
+struct BlockMergeInfo {
+ Block *mergeBlock;
+ Block *continueBlock; // nullptr for spv.selection
+
+ BlockMergeInfo() : mergeBlock(nullptr), continueBlock(nullptr) {}
+ BlockMergeInfo(Block *m, Block *c = nullptr)
+ : mergeBlock(m), continueBlock(c) {}
+};
+
+/// Map from a selection/loop's header block to its merge (and continue) target.
+using BlockMergeInfoMap = DenseMap<Block *, BlockMergeInfo>;
+
+/// A SPIR-V module serializer.
+///
+/// A SPIR-V binary module is a single linear stream of instructions; each
+/// instruction is composed of 32-bit words. The first word of an instruction
+/// records the total number of words of that instruction using the 16
+/// higher-order bits. So this deserializer uses that to get instruction
+/// boundary and parse instructions and build a SPIR-V ModuleOp gradually.
+///
+// TODO(antiagainst): clean up created ops on errors
+class Deserializer {
+public:
+ /// Creates a deserializer for the given SPIR-V `binary` module.
+ /// The SPIR-V ModuleOp will be created into `context.
+ explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
+
+ /// Deserializes the remembered SPIR-V binary module.
+ LogicalResult deserialize();
+
+ /// Collects the final SPIR-V ModuleOp.
+ Optional<spirv::ModuleOp> collect();
+
+private:
+ //===--------------------------------------------------------------------===//
+ // Module structure
+ //===--------------------------------------------------------------------===//
+
+ /// Initializes the `module` ModuleOp in this deserializer instance.
+ spirv::ModuleOp createModuleOp();
+
+ /// Processes SPIR-V module header in `binary`.
+ LogicalResult processHeader();
+
+ /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping
+ /// in the deserializer.
+ LogicalResult processCapability(ArrayRef<uint32_t> operands);
+
+ /// Attaches all collected capabilities to `module` as an attribute.
+ void attachCapabilities();
+
+ /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping
+ /// in the deserializer.
+ LogicalResult processExtension(ArrayRef<uint32_t> words);
+
+ /// Processes the SPIR-V OpExtInstImport with `operands` and updates
+ /// bookkeeping in the deserializer.
+ LogicalResult processExtInstImport(ArrayRef<uint32_t> words);
+
+ /// Attaches all collected extensions to `module` as an attribute.
+ void attachExtensions();
+
+ /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
+ LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
+
+ /// Process SPIR-V OpName with `operands`.
+ LogicalResult processName(ArrayRef<uint32_t> operands);
+
+ /// Processes an OpDecorate instruction.
+ LogicalResult processDecoration(ArrayRef<uint32_t> words);
+
+ // Processes an OpMemberDecorate instruction.
+ LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
+
+ /// Processes an OpMemberName instruction.
+ LogicalResult processMemberName(ArrayRef<uint32_t> words);
+
+ /// Gets the FuncOp associated with a result <id> of OpFunction.
+ FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
+
+ /// Processes the SPIR-V function at the current `offset` into `binary`.
+ /// The operands to the OpFunction instruction is passed in as ``operands`.
+ /// This method processes each instruction inside the function and dispatches
+ /// them to their handler method accordingly.
+ LogicalResult processFunction(ArrayRef<uint32_t> operands);
+
+ /// Processes OpFunctionEnd and finalizes function. This wires up block
+ /// argument created from OpPhi instructions and also structurizes control
+ /// flow.
+ LogicalResult processFunctionEnd(ArrayRef<uint32_t> operands);
+
+ /// Gets the constant's attribute and type associated with the given <id>.
+ Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
+
+ /// Gets the constant's integer attribute with the given <id>. Returns a null
+ /// IntegerAttr if the given is not registered or does not correspond to an
+ /// integer constant.
+ IntegerAttr getConstantInt(uint32_t id);
+
+ /// Returns a symbol to be used for the function name with the given
+ /// result <id>. This tries to use the function's OpName if
+ /// exists; otherwise creates one based on the <id>.
+ std::string getFunctionSymbol(uint32_t id);
+
+ /// Returns a symbol to be used for the specialization constant with the given
+ /// result <id>. This tries to use the specialization constant's OpName if
+ /// exists; otherwise creates one based on the <id>.
+ std::string getSpecConstantSymbol(uint32_t id);
+
+ /// Gets the specialization constant with the given result <id>.
+ spirv::SpecConstantOp getSpecConstant(uint32_t id) {
+ return specConstMap.lookup(id);
+ }
+
+ /// Creates a spirv::SpecConstantOp.
+ spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
+ Attribute defaultValue);
+
+ /// Processes the OpVariable instructions at current `offset` into `binary`.
+ /// It is expected that this method is used for variables that are to be
+ /// defined at module scope and will be deserialized into a spv.globalVariable
+ /// instruction.
+ LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
+
+ /// Gets the global variable associated with a result <id> of OpVariable.
+ spirv::GlobalVariableOp getGlobalVariable(uint32_t id) {
+ return globalVariableMap.lookup(id);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Type
+ //===--------------------------------------------------------------------===//
+
+ /// Gets type for a given result <id>.
+ Type getType(uint32_t id) { return typeMap.lookup(id); }
+
+ /// Get the type associated with the result <id> of an OpUndef.
+ Type getUndefType(uint32_t id) { return undefMap.lookup(id); }
+
+ /// Returns true if the given `type` is for SPIR-V void type.
+ bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+
+ /// Processes a SPIR-V type instruction with given `opcode` and `operands` and
+ /// registers the type into `module`.
+ LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
+
+ LogicalResult processArrayType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processStructType(ArrayRef<uint32_t> operands);
+
+ //===--------------------------------------------------------------------===//
+ // Constant
+ //===--------------------------------------------------------------------===//
+
+ /// Processes a SPIR-V Op{|Spec}Constant instruction with the given
+ /// `operands`. `isSpec` indicates whether this is a specialization constant.
+ LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
+
+ /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
+ /// given `operands`. `isSpec` indicates whether this is a specialization
+ /// constant.
+ LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
+ bool isSpec);
+
+ /// Processes a SPIR-V OpConstantComposite instruction with the given
+ /// `operands`.
+ LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
+ LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+
+ //===--------------------------------------------------------------------===//
+ // Control flow
+ //===--------------------------------------------------------------------===//
+
+ /// Returns the block for the given label <id>.
+ Block *getBlock(uint32_t id) const { return blockMap.lookup(id); }
+
+ // In SPIR-V, structured control flow is explicitly declared using merge
+ // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect,
+ // we use spv.selection and spv.loop to group structured control flow.
+ // The deserializer need to turn structured control flow marked with merge
+ // instructions into using spv.selection/spv.loop ops.
+ //
+ // Because structured control flow can nest and the basic block order have
+ // flexibility, we cannot isolate a structured selection/loop without
+ // deserializing all the blocks. So we use the following approach:
+ //
+ // 1. Deserialize all basic blocks in a function and create MLIR blocks for
+ // them into the function's region. In the meanwhile, keep a map between
+ // selection/loop header blocks to their corresponding merge (and continue)
+ // target blocks.
+ // 2. For each selection/loop header block, recursively get all basic blocks
+ // reachable (except the merge block) and put them in a newly created
+ // spv.selection/spv.loop's region. Structured control flow guarantees
+ // that we enter and exit in structured ways and the construct is nestable.
+ // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge
+ // block and redirect all branches to the old header block to the old
+ // merge block (which contains the spv.selection/spv.loop op now).
+
+ /// For OpPhi instructions, we use block arguments to represent them. OpPhi
+ /// encodes a list of (value, predecessor) pairs. At the time of handling the
+ /// block containing an OpPhi instruction, the predecessor block might not be
+ /// processed yet, also the value sent by it. So we need to defer handling
+ /// the block argument from the predecessors. We use the following approach:
+ ///
+ /// 1. For each OpPhi instruction, add a block argument to the current block
+ /// in construction. Record the block argument in `valueMap` so its uses
+ /// can be resolved. For the list of (value, predecessor) pairs, update
+ /// `blockPhiInfo` for bookkeeping.
+ /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each
+ /// block recorded there to create the proper block arguments on their
+ /// terminators.
+
+ /// A data structure for containing a SPIR-V block's phi info. It will be
+ /// represented as block argument in SPIR-V dialect.
+ using BlockPhiInfo =
+ SmallVector<uint32_t, 2>; // The result <id> of the values sent
+
+ /// Gets or creates the block corresponding to the given label <id>. The newly
+ /// created block will always be placed at the end of the current function.
+ Block *getOrCreateBlock(uint32_t id);
+
+ LogicalResult processBranch(ArrayRef<uint32_t> operands);
+
+ LogicalResult processBranchConditional(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpLabel instruction with the given `operands`.
+ LogicalResult processLabel(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`.
+ LogicalResult processSelectionMerge(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`.
+ LogicalResult processLoopMerge(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpPhi instruction with the given `operands`.
+ LogicalResult processPhi(ArrayRef<uint32_t> operands);
+
+ /// Creates block arguments on predecessors previously recorded when handling
+ /// OpPhi instructions.
+ LogicalResult wireUpBlockArgument();
+
+ /// Extracts blocks belonging to a structured selection/loop into a
+ /// spv.selection/spv.loop op. This method iterates until all blocks
+ /// declared as selection/loop headers are handled.
+ LogicalResult structurizeControlFlow();
+
+ //===--------------------------------------------------------------------===//
+ // Instruction
+ //===--------------------------------------------------------------------===//
+
+ /// Get the Value associated with a result <id>.
+ ///
+ /// This method materializes normal constants and inserts "casting" ops
+ /// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA
+ /// value for handling uses of module scope constants/variables in functions.
+ Value getValue(uint32_t id);
+
+ /// Slices the first instruction out of `binary` and returns its opcode and
+ /// operands via `opcode` and `operands` respectively. Returns failure if
+ /// there is no more remaining instructions (`expectedOpcode` will be used to
+ /// compose the error message) or the next instruction is malformed.
+ LogicalResult
+ sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
+ Optional<spirv::Opcode> expectedOpcode = llvm::None);
+
+ /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
+ /// This method is the main entrance for handling SPIR-V instruction; it
+ /// checks the instruction opcode and dispatches to the corresponding handler.
+ /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode)
+ /// might need to be deferred, since they contain forward references to <id>s
+ /// in the deserialized binary, but module in SPIR-V dialect expects these to
+ /// be ssa-uses.
+ LogicalResult processInstruction(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands,
+ bool deferInstructions = true);
+
+ /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current
+ /// insertion point.
+ LogicalResult processUndef(ArrayRef<uint32_t> operands);
+
+ /// Processes an OpBitcast instruction.
+ LogicalResult processBitcast(ArrayRef<uint32_t> words);
+
+ /// Method to dispatch to the specialized deserialization function for an
+ /// operation in SPIR-V dialect that is a mirror of an instruction in the
+ /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
+ /// all operations in SPIR-V dialect that have hasOpcode == 1.
+ LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
+ ArrayRef<uint32_t> words);
+
+ /// Processes a SPIR-V OpExtInst with given `operands`. This slices the
+ /// entries of `operands` that specify the extended instruction set <id> and
+ /// the instruction opcode. The op deserializer is then invoked using the
+ /// other entries.
+ LogicalResult processExtInst(ArrayRef<uint32_t> operands);
+
+ /// Dispatches the deserialization of extended instruction set operation based
+ /// on the extended instruction set name, and instruction opcode. This is
+ /// autogenerated from ODS.
+ LogicalResult
+ dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName,
+ uint32_t instructionID,
+ ArrayRef<uint32_t> words);
+
+ /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
+ /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
+ /// == 1 and autogenSerialization == 1 in ODS.
+ template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
+ return emitError(unknownLoc, "unsupported deserialization for ")
+ << OpTy::getOperationName() << " op";
+ }
+
+private:
+ /// The SPIR-V binary module.
+ ArrayRef<uint32_t> binary;
+
+ /// The current word offset into the binary module.
+ unsigned curOffset = 0;
+
+ /// MLIRContext to create SPIR-V ModuleOp into.
+ MLIRContext *context;
+
+ // TODO(antiagainst): create Location subclass for binary blob
+ Location unknownLoc;
+
+ /// The SPIR-V ModuleOp.
+ Optional<spirv::ModuleOp> module;
+
+ /// The current function under construction.
+ Optional<FuncOp> curFunction;
+
+ /// The current block under construction.
+ Block *curBlock = nullptr;
+
+ OpBuilder opBuilder;
+
+ /// The list of capabilities used by the module.
+ llvm::SmallSetVector<spirv::Capability, 4> capabilities;
+
+ /// The list of extensions used by the module.
+ llvm::SmallSetVector<StringRef, 2> extensions;
+
+ // Result <id> to type mapping.
+ DenseMap<uint32_t, Type> typeMap;
+
+ // Result <id> to constant attribute and type mapping.
+ ///
+ /// In the SPIR-V binary format, all constants are placed in the module and
+ /// shared by instructions at module level and in subsequent functions. But in
+ /// the SPIR-V dialect, we materialize the constant to where it's used in the
+ /// function. So when seeing a constant instruction in the binary format, we
+ /// don't immediately emit a constant op into the module, we keep its value
+ /// (and type) here. Later when it's used, we materialize the constant.
+ DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
+
+ // Result <id> to variable mapping.
+ DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
+
+ // Result <id> to variable mapping.
+ DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+
+ // Result <id> to function mapping.
+ DenseMap<uint32_t, FuncOp> funcMap;
+
+ // Result <id> to block mapping.
+ DenseMap<uint32_t, Block *> blockMap;
+
+ // Header block to its merge (and continue) target mapping.
+ BlockMergeInfoMap blockMergeInfo;
+
+ // Block to its phi (block argument) mapping.
+ DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
+
+ // Result <id> to value mapping.
+ DenseMap<uint32_t, Value> valueMap;
+
+ // Mapping from result <id> to undef value of a type.
+ DenseMap<uint32_t, Type> undefMap;
+
+ // Result <id> to name mapping.
+ DenseMap<uint32_t, StringRef> nameMap;
+
+ // Result <id> to decorations mapping.
+ DenseMap<uint32_t, NamedAttributeList> decorations;
+
+ // Result <id> to type decorations.
+ DenseMap<uint32_t, uint32_t> typeDecorations;
+
+ // Result <id> to member decorations.
+ // decorated-struct-type-<id> ->
+ // (struct-member-index -> (decoration -> decoration-operands))
+ DenseMap<uint32_t,
+ DenseMap<uint32_t, DenseMap<spirv::Decoration, ArrayRef<uint32_t>>>>
+ memberDecorationMap;
+
+ // Result <id> to member name.
+ // struct-type-<id> -> (struct-member-index -> name)
+ DenseMap<uint32_t, DenseMap<uint32_t, StringRef>> memberNameMap;
+
+ // Result <id> to extended instruction set name.
+ DenseMap<uint32_t, StringRef> extendedInstSets;
+
+ // List of instructions that are processed in a deferred fashion (after an
+ // initial processing of the entire binary). Some operations like
+ // OpEntryPoint, and OpExecutionMode use forward references to function
+ // <id>s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and
+ // spv.ExecutionMode) need these references resolved. So these instructions
+ // are deserialized and stored for processing once the entire binary is
+ // processed.
+ SmallVector<std::pair<spirv::Opcode, ArrayRef<uint32_t>>, 4>
+ deferredInstructions;
+};
+} // namespace
+
+Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
+ : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
+ module(createModuleOp()), opBuilder(module->body()) {}
+
+LogicalResult Deserializer::deserialize() {
+ LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
+
+ if (failed(processHeader()))
+ return failure();
+
+ spirv::Opcode opcode = spirv::Opcode::OpNop;
+ ArrayRef<uint32_t> operands;
+ auto binarySize = binary.size();
+ while (curOffset < binarySize) {
+ // Slice the next instruction out and populate `opcode` and `operands`.
+ // Internally this also updates `curOffset`.
+ if (failed(sliceInstruction(opcode, operands)))
+ return failure();
+
+ if (failed(processInstruction(opcode, operands)))
+ return failure();
+ }
+
+ assert(curOffset == binarySize &&
+ "deserializer should never index beyond the binary end");
+
+ for (auto &deferred : deferredInstructions) {
+ if (failed(processInstruction(deferred.first, deferred.second, false))) {
+ return failure();
+ }
+ }
+
+ // Attaches the capabilities/extensions as an attribute to the module.
+ attachCapabilities();
+ attachExtensions();
+
+ LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n");
+ return success();
+}
+
+Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
+
+//===----------------------------------------------------------------------===//
+// Module structure
+//===----------------------------------------------------------------------===//
+
+spirv::ModuleOp Deserializer::createModuleOp() {
+ Builder builder(context);
+ OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
+ // TODO(antiagainst): use target environment to select the version
+ state.addAttribute("major_version", builder.getI32IntegerAttr(1));
+ state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
+ spirv::ModuleOp::build(&builder, state);
+ return cast<spirv::ModuleOp>(Operation::create(state));
+}
+
+LogicalResult Deserializer::processHeader() {
+ if (binary.size() < spirv::kHeaderWordCount)
+ return emitError(unknownLoc,
+ "SPIR-V binary module must have a 5-word header");
+
+ if (binary[0] != spirv::kMagicNumber)
+ return emitError(unknownLoc, "incorrect magic number");
+
+ // TODO(antiagainst): generator number, bound, schema
+ curOffset = spirv::kHeaderWordCount;
+ return success();
+}
+
+LogicalResult Deserializer::processCapability(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 1)
+ return emitError(unknownLoc, "OpMemoryModel must have one parameter");
+
+ auto cap = spirv::symbolizeCapability(operands[0]);
+ if (!cap)
+ return emitError(unknownLoc, "unknown capability: ") << operands[0];
+
+ capabilities.insert(*cap);
+ return success();
+}
+
+void Deserializer::attachCapabilities() {
+ if (capabilities.empty())
+ return;
+
+ SmallVector<StringRef, 2> caps;
+ caps.reserve(capabilities.size());
+
+ for (auto cap : capabilities) {
+ caps.push_back(spirv::stringifyCapability(cap));
+ }
+
+ module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps));
+}
+
+LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> words) {
+ if (words.empty()) {
+ return emitError(
+ unknownLoc,
+ "OpExtension must have a literal string for the extension name");
+ }
+
+ unsigned wordIndex = 0;
+ StringRef extName = decodeStringLiteral(words, wordIndex);
+ if (wordIndex != words.size()) {
+ return emitError(unknownLoc,
+ "unexpected trailing words in OpExtension instruction");
+ }
+
+ extensions.insert(extName);
+ return success();
+}
+
+LogicalResult Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
+ if (words.size() < 2) {
+ return emitError(unknownLoc,
+ "OpExtInstImport must have a result <id> and a literal "
+ "string for the extended instruction set name");
+ }
+
+ unsigned wordIndex = 1;
+ extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
+ if (wordIndex != words.size()) {
+ return emitError(unknownLoc,
+ "unexpected trailing words in OpExtInstImport");
+ }
+ return success();
+}
+
+void Deserializer::attachExtensions() {
+ if (extensions.empty())
+ return;
+
+ module->setAttr("extensions",
+ opBuilder.getStrArrayAttr(extensions.getArrayRef()));
+}
+
+LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2)
+ return emitError(unknownLoc, "OpMemoryModel must have two operands");
+
+ module->setAttr(
+ "addressing_model",
+ opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
+ module->setAttr(
+ "memory_model",
+ opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
+
+ return success();
+}
+
+LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
+ // TODO : This function should also be auto-generated. For now, since only a
+ // few decorations are processed/handled in a meaningful manner, going with a
+ // manual implementation.
+ if (words.size() < 2) {
+ return emitError(
+ unknownLoc, "OpDecorate must have at least result <id> and Decoration");
+ }
+ auto decorationName =
+ stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
+ if (decorationName.empty()) {
+ return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
+ }
+ auto attrName = convertToSnakeCase(decorationName);
+ auto symbol = opBuilder.getIdentifier(attrName);
+ switch (static_cast<spirv::Decoration>(words[1])) {
+ case spirv::Decoration::DescriptorSet:
+ case spirv::Decoration::Binding:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
+ break;
+ case spirv::Decoration::BuiltIn:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, opBuilder.getStringAttr(
+ stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
+ break;
+ case spirv::Decoration::ArrayStride:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ typeDecorations[words[0]] = words[2];
+ break;
+ case spirv::Decoration::Block:
+ case spirv::Decoration::BufferBlock:
+ if (words.size() != 2) {
+ return emitError(unknownLoc, "OpDecoration with ")
+ << decorationName << "needs a single target <id>";
+ }
+ // Block decoration does not affect spv.struct type, but is still stored for
+ // verification.
+ // TODO: Update StructType to contain this information since
+ // it is needed for many validation rules.
+ decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
+ break;
+ case spirv::Decoration::SpecId:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecoration with ")
+ << decorationName << "needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
+ break;
+ default:
+ return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
+ }
+ return success();
+}
+
+LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
+ // The binary layout of OpMemberDecorate is different comparing to OpDecorate
+ if (words.size() < 3) {
+ return emitError(unknownLoc,
+ "OpMemberDecorate must have at least 3 operands");
+ }
+
+ auto decoration = static_cast<spirv::Decoration>(words[2]);
+ if (decoration == spirv::Decoration::Offset && words.size() != 4) {
+ return emitError(unknownLoc,
+ " missing offset specification in OpMemberDecorate with "
+ "Offset decoration");
+ }
+ ArrayRef<uint32_t> decorationOperands;
+ if (words.size() > 3) {
+ decorationOperands = words.slice(3);
+ }
+ memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
+ return success();
+}
+
+LogicalResult Deserializer::processMemberName(ArrayRef<uint32_t> words) {
+ if (words.size() < 3) {
+ return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
+ }
+ unsigned wordIndex = 2;
+ auto name = decodeStringLiteral(words, wordIndex);
+ if (wordIndex != words.size()) {
+ return emitError(unknownLoc,
+ "unexpected trailing words in OpMemberName instruction");
+ }
+ memberNameMap[words[0]][words[1]] = name;
+ return success();
+}
+
+LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
+ if (curFunction) {
+ return emitError(unknownLoc, "found function inside function");
+ }
+
+ // Get the result type
+ if (operands.size() != 4) {
+ return emitError(unknownLoc, "OpFunction must have 4 parameters");
+ }
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ if (funcMap.count(operands[1])) {
+ return emitError(unknownLoc, "duplicate function definition/declaration");
+ }
+
+ auto functionControl = spirv::symbolizeFunctionControl(operands[2]);
+ if (!functionControl) {
+ return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
+ }
+ if (functionControl.getValue() != spirv::FunctionControl::None) {
+ /// TODO : Handle different function controls
+ return emitError(unknownLoc, "unhandled Function Control: '")
+ << spirv::stringifyFunctionControl(functionControl.getValue())
+ << "'";
+ }
+
+ Type fnType = getType(operands[3]);
+ if (!fnType || !fnType.isa<FunctionType>()) {
+ return emitError(unknownLoc, "unknown function type from <id> ")
+ << operands[3];
+ }
+ auto functionType = fnType.cast<FunctionType>();
+
+ if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
+ (functionType.getNumResults() == 1 &&
+ functionType.getResult(0) != resultType)) {
+ return emitError(unknownLoc, "mismatch in function type ")
+ << functionType << " and return type " << resultType << " specified";
+ }
+
+ std::string fnName = getFunctionSymbol(operands[1]);
+ auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
+ ArrayRef<NamedAttribute>());
+ curFunction = funcMap[operands[1]] = funcOp;
+ LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = "
+ << fnType << ", id = " << operands[1] << ") --\n");
+ auto *entryBlock = funcOp.addEntryBlock();
+ LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock
+ << "\n");
+
+ // Parse the op argument instructions
+ if (functionType.getNumInputs()) {
+ for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
+ auto argType = functionType.getInput(i);
+ spirv::Opcode opcode = spirv::Opcode::OpNop;
+ ArrayRef<uint32_t> operands;
+ if (failed(sliceInstruction(opcode, operands,
+ spirv::Opcode::OpFunctionParameter))) {
+ return failure();
+ }
+ if (opcode != spirv::Opcode::OpFunctionParameter) {
+ return emitError(
+ unknownLoc,
+ "missing OpFunctionParameter instruction for argument ")
+ << i;
+ }
+ if (operands.size() != 2) {
+ return emitError(
+ unknownLoc,
+ "expected result type and result <id> for OpFunctionParameter");
+ }
+ auto argDefinedType = getType(operands[0]);
+ if (!argDefinedType || argDefinedType != argType) {
+ return emitError(unknownLoc,
+ "mismatch in argument type between function type "
+ "definition ")
+ << functionType << " and argument type definition "
+ << argDefinedType << " at argument " << i;
+ }
+ if (getValue(operands[1])) {
+ return emitError(unknownLoc, "duplicate definition of result <id> '")
+ << operands[1];
+ }
+ auto argValue = funcOp.getArgument(i);
+ valueMap[operands[1]] = argValue;
+ }
+ }
+
+ // RAII guard to reset the insertion point to the module's region after
+ // deserializing the body of this function.
+ OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+ spirv::Opcode opcode = spirv::Opcode::OpNop;
+ ArrayRef<uint32_t> instOperands;
+
+ // Special handling for the entry block. We need to make sure it starts with
+ // an OpLabel instruction. The entry block takes the same parameters as the
+ // function. All other blocks do not take any parameter. We have already
+ // created the entry block, here we need to register it to the correct label
+ // <id>.
+ if (failed(sliceInstruction(opcode, instOperands,
+ spirv::Opcode::OpFunctionEnd))) {
+ return failure();
+ }
+ if (opcode == spirv::Opcode::OpFunctionEnd) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "-- completed function '" << fnName << "' (type = " << fnType
+ << ", id = " << operands[1] << ") --\n");
+ return processFunctionEnd(instOperands);
+ }
+ if (opcode != spirv::Opcode::OpLabel) {
+ return emitError(unknownLoc, "a basic block must start with OpLabel");
+ }
+ if (instOperands.size() != 1) {
+ return emitError(unknownLoc, "OpLabel should only have result <id>");
+ }
+ blockMap[instOperands[0]] = entryBlock;
+ if (failed(processLabel(instOperands))) {
+ return failure();
+ }
+
+ // Then process all the other instructions in the function until we hit
+ // OpFunctionEnd.
+ while (succeeded(sliceInstruction(opcode, instOperands,
+ spirv::Opcode::OpFunctionEnd)) &&
+ opcode != spirv::Opcode::OpFunctionEnd) {
+ if (failed(processInstruction(opcode, instOperands))) {
+ return failure();
+ }
+ }
+ if (opcode != spirv::Opcode::OpFunctionEnd) {
+ return failure();
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = "
+ << fnType << ", id = " << operands[1] << ") --\n");
+ return processFunctionEnd(instOperands);
+}
+
+LogicalResult Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
+ // Process OpFunctionEnd.
+ if (!operands.empty()) {
+ return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
+ }
+
+ // Wire up block arguments from OpPhi instructions.
+ // Put all structured control flow in spv.selection/spv.loop ops.
+ if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
+ return failure();
+ }
+
+ curBlock = nullptr;
+ curFunction = llvm::None;
+
+ return success();
+}
+
+Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
+ auto constIt = constantMap.find(id);
+ if (constIt == constantMap.end())
+ return llvm::None;
+ return constIt->getSecond();
+}
+
+std::string Deserializer::getFunctionSymbol(uint32_t id) {
+ auto funcName = nameMap.lookup(id).str();
+ if (funcName.empty()) {
+ funcName = "spirv_fn_" + std::to_string(id);
+ }
+ return funcName;
+}
+
+std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
+ auto constName = nameMap.lookup(id).str();
+ if (constName.empty()) {
+ constName = "spirv_spec_const_" + std::to_string(id);
+ }
+ return constName;
+}
+
+spirv::SpecConstantOp Deserializer::createSpecConstant(Location loc,
+ uint32_t resultID,
+ Attribute defaultValue) {
+ auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+ auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
+ defaultValue);
+ if (decorations.count(resultID)) {
+ for (auto attr : decorations[resultID].getAttrs())
+ op.setAttr(attr.first, attr.second);
+ }
+ specConstMap[resultID] = op;
+ return op;
+}
+
+LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
+ unsigned wordIndex = 0;
+ if (operands.size() < 3) {
+ return emitError(
+ unknownLoc,
+ "OpVariable needs at least 3 operands, type, <id> and storage class");
+ }
+
+ // Result Type.
+ auto type = getType(operands[wordIndex]);
+ if (!type) {
+ return emitError(unknownLoc, "unknown result type <id> : ")
+ << operands[wordIndex];
+ }
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return emitError(unknownLoc,
+ "expected a result type <id> to be a spv.ptr, found : ")
+ << type;
+ }
+ wordIndex++;
+
+ // Result <id>.
+ auto variableID = operands[wordIndex];
+ auto variableName = nameMap.lookup(variableID).str();
+ if (variableName.empty()) {
+ variableName = "spirv_var_" + std::to_string(variableID);
+ }
+ wordIndex++;
+
+ // Storage class.
+ auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
+ if (ptrType.getStorageClass() != storageClass) {
+ return emitError(unknownLoc, "mismatch in storage class of pointer type ")
+ << type << " and that specified in OpVariable instruction : "
+ << stringifyStorageClass(storageClass);
+ }
+ wordIndex++;
+
+ // Initializer.
+ FlatSymbolRefAttr initializer = nullptr;
+ if (wordIndex < operands.size()) {
+ auto initializerOp = getGlobalVariable(operands[wordIndex]);
+ if (!initializerOp) {
+ return emitError(unknownLoc, "unknown <id> ")
+ << operands[wordIndex] << "used as initializer";
+ }
+ wordIndex++;
+ initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
+ }
+ if (wordIndex != operands.size()) {
+ return emitError(unknownLoc,
+ "found more operands than expected when deserializing "
+ "OpVariable instruction, only ")
+ << wordIndex << " of " << operands.size() << " processed";
+ }
+ auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
+ unknownLoc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
+ initializer);
+
+ // Decorations.
+ if (decorations.count(variableID)) {
+ for (auto attr : decorations[variableID].getAttrs()) {
+ varOp.setAttr(attr.first, attr.second);
+ }
+ }
+ globalVariableMap[variableID] = varOp;
+ return success();
+}
+
+IntegerAttr Deserializer::getConstantInt(uint32_t id) {
+ auto constInfo = getConstant(id);
+ if (!constInfo) {
+ return nullptr;
+ }
+ return constInfo->first.dyn_cast<IntegerAttr>();
+}
+
+LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 2) {
+ return emitError(unknownLoc, "OpName needs at least 2 operands");
+ }
+ if (!nameMap.lookup(operands[0]).empty()) {
+ return emitError(unknownLoc, "duplicate name found for result <id> ")
+ << operands[0];
+ }
+ unsigned wordIndex = 1;
+ StringRef name = decodeStringLiteral(operands, wordIndex);
+ if (wordIndex != operands.size()) {
+ return emitError(unknownLoc,
+ "unexpected trailing words in OpName instruction");
+ }
+ nameMap[operands[0]] = name;
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::processType(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands) {
+ if (operands.empty()) {
+ return emitError(unknownLoc, "type instruction with opcode ")
+ << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
+ }
+
+ /// TODO: Types might be forward declared in some instructions and need to be
+ /// handled appropriately.
+ if (typeMap.count(operands[0])) {
+ return emitError(unknownLoc, "duplicate definition for result <id> ")
+ << operands[0];
+ }
+
+ switch (opcode) {
+ case spirv::Opcode::OpTypeVoid:
+ if (operands.size() != 1) {
+ return emitError(unknownLoc, "OpTypeVoid must have no parameters");
+ }
+ typeMap[operands[0]] = opBuilder.getNoneType();
+ break;
+ case spirv::Opcode::OpTypeBool:
+ if (operands.size() != 1) {
+ return emitError(unknownLoc, "OpTypeBool must have no parameters");
+ }
+ typeMap[operands[0]] = opBuilder.getI1Type();
+ break;
+ case spirv::Opcode::OpTypeInt:
+ if (operands.size() != 3) {
+ return emitError(
+ unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
+ }
+ // TODO: Ignoring the signedness right now. Need to handle this effectively
+ // in the MLIR representation.
+ typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]);
+ break;
+ case spirv::Opcode::OpTypeFloat: {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+ }
+ Type floatTy;
+ switch (operands[1]) {
+ case 16:
+ floatTy = opBuilder.getF16Type();
+ break;
+ case 32:
+ floatTy = opBuilder.getF32Type();
+ break;
+ case 64:
+ floatTy = opBuilder.getF64Type();
+ break;
+ default:
+ return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
+ << operands[1];
+ }
+ typeMap[operands[0]] = floatTy;
+ } break;
+ case spirv::Opcode::OpTypeVector: {
+ if (operands.size() != 3) {
+ return emitError(
+ unknownLoc,
+ "OpTypeVector must have element type and count parameters");
+ }
+ Type elementTy = getType(operands[1]);
+ if (!elementTy) {
+ return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
+ << operands[1];
+ }
+ typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+ } break;
+ case spirv::Opcode::OpTypePointer: {
+ if (operands.size() != 3) {
+ return emitError(unknownLoc, "OpTypePointer must have two parameters");
+ }
+ auto pointeeType = getType(operands[2]);
+ if (!pointeeType) {
+ return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
+ << operands[2];
+ }
+ auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
+ typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
+ } break;
+ case spirv::Opcode::OpTypeArray:
+ return processArrayType(operands);
+ case spirv::Opcode::OpTypeFunction:
+ return processFunctionType(operands);
+ case spirv::Opcode::OpTypeRuntimeArray:
+ return processRuntimeArrayType(operands);
+ case spirv::Opcode::OpTypeStruct:
+ return processStructType(operands);
+ default:
+ return emitError(unknownLoc, "unhandled type instruction");
+ }
+ return success();
+}
+
+LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 3) {
+ return emitError(unknownLoc,
+ "OpTypeArray must have element type and count parameters");
+ }
+
+ Type elementTy = getType(operands[1]);
+ if (!elementTy) {
+ return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
+ << operands[1];
+ }
+
+ unsigned count = 0;
+ // TODO(antiagainst): The count can also come frome a specialization constant.
+ auto countInfo = getConstant(operands[2]);
+ if (!countInfo) {
+ return emitError(unknownLoc, "OpTypeArray count <id> ")
+ << operands[2] << "can only come from normal constant right now";
+ }
+
+ if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
+ count = intVal.getInt();
+ } else {
+ return emitError(unknownLoc, "OpTypeArray count must come from a "
+ "scalar integer constant instruction");
+ }
+
+ typeMap[operands[0]] = spirv::ArrayType::get(
+ elementTy, count, typeDecorations.lookup(operands[0]));
+ return success();
+}
+
+LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
+ assert(!operands.empty() && "No operands for processing function type");
+ if (operands.size() == 1) {
+ return emitError(unknownLoc, "missing return type for OpTypeFunction");
+ }
+ auto returnType = getType(operands[1]);
+ if (!returnType) {
+ return emitError(unknownLoc, "unknown return type in OpTypeFunction");
+ }
+ SmallVector<Type, 1> argTypes;
+ for (size_t i = 2, e = operands.size(); i < e; ++i) {
+ auto ty = getType(operands[i]);
+ if (!ty) {
+ return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
+ }
+ argTypes.push_back(ty);
+ }
+ ArrayRef<Type> returnTypes;
+ if (!isVoidType(returnType)) {
+ returnTypes = llvm::makeArrayRef(returnType);
+ }
+ typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
+ return success();
+}
+
+LogicalResult
+Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
+ }
+ Type memberType = getType(operands[1]);
+ if (!memberType) {
+ return emitError(unknownLoc,
+ "OpTypeRuntimeArray references undefined <id> ")
+ << operands[1];
+ }
+ typeMap[operands[0]] = spirv::RuntimeArrayType::get(memberType);
+ return success();
+}
+
+LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
+ if (operands.empty()) {
+ return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
+ }
+ if (operands.size() == 1) {
+ // Handle empty struct.
+ typeMap[operands[0]] = spirv::StructType::getEmpty(context);
+ return success();
+ }
+
+ SmallVector<Type, 0> memberTypes;
+ for (auto op : llvm::drop_begin(operands, 1)) {
+ Type memberType = getType(op);
+ if (!memberType) {
+ return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
+ << op;
+ }
+ memberTypes.push_back(memberType);
+ }
+
+ SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
+ SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ if (memberDecorationMap.count(operands[0])) {
+ auto &allMemberDecorations = memberDecorationMap[operands[0]];
+ for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
+ if (allMemberDecorations.count(memberIndex)) {
+ for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
+ // Check for offset.
+ if (memberDecoration.first == spirv::Decoration::Offset) {
+ // If layoutInfo is empty, resize to the number of members;
+ if (layoutInfo.empty()) {
+ layoutInfo.resize(memberTypes.size());
+ }
+ layoutInfo[memberIndex] = memberDecoration.second[0];
+ } else {
+ if (!memberDecoration.second.empty()) {
+ return emitError(unknownLoc,
+ "unhandled OpMemberDecoration with decoration ")
+ << stringifyDecoration(memberDecoration.first)
+ << " which has additional operands";
+ }
+ memberDecorationsInfo.emplace_back(memberIndex,
+ memberDecoration.first);
+ }
+ }
+ }
+ }
+ }
+ typeMap[operands[0]] =
+ spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
+ // TODO(ravishankarm): Update StructType to have member name as attribute as
+ // well.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Constant
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
+ bool isSpec) {
+ StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
+
+ if (operands.size() < 2) {
+ return emitError(unknownLoc)
+ << opname << " must have type <id> and result <id>";
+ }
+ if (operands.size() < 3) {
+ return emitError(unknownLoc)
+ << opname << " must have at least 1 more parameter";
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
+ if (bitwidth == 64) {
+ if (operands.size() == 4) {
+ return success();
+ }
+ return emitError(unknownLoc)
+ << opname << " should have 2 parameters for 64-bit values";
+ }
+ if (bitwidth <= 32) {
+ if (operands.size() == 3) {
+ return success();
+ }
+
+ return emitError(unknownLoc)
+ << opname
+ << " should have 1 parameter for values with no more than 32 bits";
+ }
+ return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
+ << bitwidth;
+ };
+
+ auto resultID = operands[1];
+
+ if (auto intType = resultType.dyn_cast<IntegerType>()) {
+ auto bitwidth = intType.getWidth();
+ if (failed(checkOperandSizeForBitwidth(bitwidth))) {
+ return failure();
+ }
+
+ APInt value;
+ if (bitwidth == 64) {
+ // 64-bit integers are represented with two SPIR-V words. According to
+ // SPIR-V spec: "When the type’s bit width is larger than one word, the
+ // literal’s low-order words appear first."
+ struct DoubleWord {
+ uint32_t word1;
+ uint32_t word2;
+ } words = {operands[2], operands[3]};
+ value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
+ } else if (bitwidth <= 32) {
+ value = APInt(bitwidth, operands[2], /*isSigned=*/true);
+ }
+
+ auto attr = opBuilder.getIntegerAttr(intType, value);
+
+ if (isSpec) {
+ createSpecConstant(unknownLoc, resultID, attr);
+ } else {
+ // For normal constants, we just record the attribute (and its type) for
+ // later materialization at use sites.
+ constantMap.try_emplace(resultID, attr, intType);
+ }
+
+ return success();
+ }
+
+ if (auto floatType = resultType.dyn_cast<FloatType>()) {
+ auto bitwidth = floatType.getWidth();
+ if (failed(checkOperandSizeForBitwidth(bitwidth))) {
+ return failure();
+ }
+
+ APFloat value(0.f);
+ if (floatType.isF64()) {
+ // Double values are represented with two SPIR-V words. According to
+ // SPIR-V spec: "When the type’s bit width is larger than one word, the
+ // literal’s low-order words appear first."
+ struct DoubleWord {
+ uint32_t word1;
+ uint32_t word2;
+ } words = {operands[2], operands[3]};
+ value = APFloat(llvm::bit_cast<double>(words));
+ } else if (floatType.isF32()) {
+ value = APFloat(llvm::bit_cast<float>(operands[2]));
+ } else if (floatType.isF16()) {
+ APInt data(16, operands[2]);
+ value = APFloat(APFloat::IEEEhalf(), data);
+ }
+
+ auto attr = opBuilder.getFloatAttr(floatType, value);
+ if (isSpec) {
+ createSpecConstant(unknownLoc, resultID, attr);
+ } else {
+ // For normal constants, we just record the attribute (and its type) for
+ // later materialization at use sites.
+ constantMap.try_emplace(resultID, attr, floatType);
+ }
+
+ return success();
+ }
+
+ return emitError(unknownLoc, "OpConstant can only generate values of "
+ "scalar integer or floating-point type");
+}
+
+LogicalResult Deserializer::processConstantBool(bool isTrue,
+ ArrayRef<uint32_t> operands,
+ bool isSpec) {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc, "Op")
+ << (isSpec ? "Spec" : "") << "Constant"
+ << (isTrue ? "True" : "False")
+ << " must have type <id> and result <id>";
+ }
+
+ auto attr = opBuilder.getBoolAttr(isTrue);
+ auto resultID = operands[1];
+ if (isSpec) {
+ createSpecConstant(unknownLoc, resultID, attr);
+ } else {
+ // For normal constants, we just record the attribute (and its type) for
+ // later materialization at use sites.
+ constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
+ }
+
+ return success();
+}
+
+LogicalResult
+Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 2) {
+ return emitError(unknownLoc,
+ "OpConstantComposite must have type <id> and result <id>");
+ }
+ if (operands.size() < 3) {
+ return emitError(unknownLoc,
+ "OpConstantComposite must have at least 1 parameter");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ SmallVector<Attribute, 4> elements;
+ elements.reserve(operands.size() - 2);
+ for (unsigned i = 2, e = operands.size(); i < e; ++i) {
+ auto elementInfo = getConstant(operands[i]);
+ if (!elementInfo) {
+ return emitError(unknownLoc, "OpConstantComposite component <id> ")
+ << operands[i] << " must come from a normal constant";
+ }
+ elements.push_back(elementInfo->first);
+ }
+
+ auto resultID = operands[1];
+ if (auto vectorType = resultType.dyn_cast<VectorType>()) {
+ auto attr = DenseElementsAttr::get(vectorType, elements);
+ // For normal constants, we just record the attribute (and its type) for
+ // later materialization at use sites.
+ constantMap.try_emplace(resultID, attr, resultType);
+ } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
+ auto attr = opBuilder.getArrayAttr(elements);
+ constantMap.try_emplace(resultID, attr, resultType);
+ } else {
+ return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
+ << resultType;
+ }
+
+ return success();
+}
+
+LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc,
+ "OpConstantNull must have type <id> and result <id>");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto resultID = operands[1];
+ if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
+ resultType.isa<VectorType>()) {
+ auto attr = opBuilder.getZeroAttr(resultType);
+ // For normal constants, we just record the attribute (and its type) for
+ // later materialization at use sites.
+ constantMap.try_emplace(resultID, attr, resultType);
+ return success();
+ }
+
+ return emitError(unknownLoc, "unsupported OpConstantNull type: ")
+ << resultType;
+}
+
+//===----------------------------------------------------------------------===//
+// Control flow
+//===----------------------------------------------------------------------===//
+
+Block *Deserializer::getOrCreateBlock(uint32_t id) {
+ if (auto *block = getBlock(id)) {
+ LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
+ << " @ " << block << "\n");
+ return block;
+ }
+
+ // We don't know where this block will be placed finally (in a spv.selection
+ // or spv.loop or function). Create it into the function for now and sort
+ // out the proper place later.
+ auto *block = curFunction->addBlock();
+ LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ "
+ << block << "\n");
+ return blockMap[id] = block;
+}
+
+LogicalResult Deserializer::processBranch(ArrayRef<uint32_t> operands) {
+ if (!curBlock) {
+ return emitError(unknownLoc, "OpBranch must appear inside a block");
+ }
+
+ if (operands.size() != 1) {
+ return emitError(unknownLoc, "OpBranch must take exactly one target label");
+ }
+
+ auto *target = getOrCreateBlock(operands[0]);
+ opBuilder.create<spirv::BranchOp>(unknownLoc, target);
+
+ return success();
+}
+
+LogicalResult
+Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
+ if (!curBlock) {
+ return emitError(unknownLoc,
+ "OpBranchConditional must appear inside a block");
+ }
+
+ if (operands.size() != 3 && operands.size() != 5) {
+ return emitError(unknownLoc,
+ "OpBranchConditional must have condition, true label, "
+ "false label, and optionally two branch weights");
+ }
+
+ auto condition = getValue(operands[0]);
+ auto *trueBlock = getOrCreateBlock(operands[1]);
+ auto *falseBlock = getOrCreateBlock(operands[2]);
+
+ Optional<std::pair<uint32_t, uint32_t>> weights;
+ if (operands.size() == 5) {
+ weights = std::make_pair(operands[3], operands[4]);
+ }
+
+ opBuilder.create<spirv::BranchConditionalOp>(
+ unknownLoc, condition, trueBlock,
+ /*trueArguments=*/ArrayRef<Value>(), falseBlock,
+ /*falseArguments=*/ArrayRef<Value>(), weights);
+
+ return success();
+}
+
+LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
+ if (!curFunction) {
+ return emitError(unknownLoc, "OpLabel must appear inside a function");
+ }
+
+ if (operands.size() != 1) {
+ return emitError(unknownLoc, "OpLabel should only have result <id>");
+ }
+
+ auto labelID = operands[0];
+ // We may have forward declared this block.
+ auto *block = getOrCreateBlock(labelID);
+ LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n");
+ // If we have seen this block, make sure it was just a forward declaration.
+ assert(block->empty() && "re-deserialize the same block!");
+
+ opBuilder.setInsertionPointToStart(block);
+ blockMap[labelID] = curBlock = block;
+
+ return success();
+}
+
+LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
+ if (!curBlock) {
+ return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
+ }
+
+ if (operands.size() < 2) {
+ return emitError(
+ unknownLoc,
+ "OpLoopMerge must specify merge target and selection control");
+ }
+
+ if (static_cast<uint32_t>(spirv::LoopControl::None) != operands[1]) {
+ return emitError(unknownLoc,
+ "unimplmented OpSelectionMerge selection control: ")
+ << operands[2];
+ }
+
+ auto *mergeBlock = getOrCreateBlock(operands[0]);
+
+ if (!blockMergeInfo.try_emplace(curBlock, mergeBlock).second) {
+ return emitError(
+ unknownLoc,
+ "a block cannot have more than one OpSelectionMerge instruction");
+ }
+
+ return success();
+}
+
+LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
+ if (!curBlock) {
+ return emitError(unknownLoc, "OpLoopMerge must appear in a block");
+ }
+
+ if (operands.size() < 3) {
+ return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
+ "continue target and loop control");
+ }
+
+ if (static_cast<uint32_t>(spirv::LoopControl::None) != operands[2]) {
+ return emitError(unknownLoc, "unimplmented OpLoopMerge loop control: ")
+ << operands[2];
+ }
+
+ auto *mergeBlock = getOrCreateBlock(operands[0]);
+ auto *continueBlock = getOrCreateBlock(operands[1]);
+
+ if (!blockMergeInfo.try_emplace(curBlock, mergeBlock, continueBlock).second) {
+ return emitError(
+ unknownLoc,
+ "a block cannot have more than one OpLoopMerge instruction");
+ }
+
+ return success();
+}
+
+LogicalResult Deserializer::processPhi(ArrayRef<uint32_t> operands) {
+ if (!curBlock) {
+ return emitError(unknownLoc, "OpPhi must appear in a block");
+ }
+
+ if (operands.size() < 4) {
+ return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
+ "and variable-parent pairs");
+ }
+
+ // Create a block argument for this OpPhi instruction.
+ Type blockArgType = getType(operands[0]);
+ BlockArgument blockArg = curBlock->addArgument(blockArgType);
+ valueMap[operands[1]] = blockArg;
+ LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg
+ << " id = " << operands[1] << " of type "
+ << blockArgType << '\n');
+
+ // For each (value, predecessor) pair, insert the value to the predecessor's
+ // blockPhiInfo entry so later we can fix the block argument there.
+ for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
+ uint32_t value = operands[i];
+ Block *predecessor = getOrCreateBlock(operands[i + 1]);
+ blockPhiInfo[predecessor].push_back(value);
+ LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
+ << " with arg id = " << value << '\n');
+ }
+
+ return success();
+}
+
+namespace {
+/// A class for putting all blocks in a structured selection/loop in a
+/// spv.selection/spv.loop op.
+class ControlFlowStructurizer {
+public:
+ /// Structurizes the loop at the given `headerBlock`.
+ ///
+ /// This method will create an spv.loop op in the `mergeBlock` and move all
+ /// blocks in the structured loop into the spv.loop's region. All branches to
+ /// the `headerBlock` will be redirected to the `mergeBlock`.
+ /// This method will also update `mergeInfo` by remapping all blocks inside to
+ /// the newly cloned ones inside structured control flow op's regions.
+ static LogicalResult structurize(Location loc, BlockMergeInfoMap &mergeInfo,
+ Block *headerBlock, Block *mergeBlock,
+ Block *continueBlock) {
+ return ControlFlowStructurizer(loc, mergeInfo, headerBlock, mergeBlock,
+ continueBlock)
+ .structurizeImpl();
+ }
+
+private:
+ ControlFlowStructurizer(Location loc, BlockMergeInfoMap &mergeInfo,
+ Block *header, Block *merge, Block *cont)
+ : location(loc), blockMergeInfo(mergeInfo), headerBlock(header),
+ mergeBlock(merge), continueBlock(cont) {}
+
+ /// Creates a new spv.selection op at the beginning of the `mergeBlock`.
+ spirv::SelectionOp createSelectionOp();
+
+ /// Creates a new spv.loop op at the beginning of the `mergeBlock`.
+ spirv::LoopOp createLoopOp();
+
+ /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
+ void collectBlocksInConstruct();
+
+ LogicalResult structurizeImpl();
+
+ Location location;
+
+ BlockMergeInfoMap &blockMergeInfo;
+
+ Block *headerBlock;
+ Block *mergeBlock;
+ Block *continueBlock; // nullptr for spv.selection
+
+ llvm::SetVector<Block *> constructBlocks;
+};
+} // namespace
+
+spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() {
+ // Create a builder and set the insertion point to the beginning of the
+ // merge block so that the newly created SelectionOp will be inserted there.
+ OpBuilder builder(&mergeBlock->front());
+
+ auto control = builder.getI32IntegerAttr(
+ static_cast<uint32_t>(spirv::SelectionControl::None));
+ auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
+ selectionOp.addMergeBlock();
+
+ return selectionOp;
+}
+
+spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
+ // Create a builder and set the insertion point to the beginning of the
+ // merge block so that the newly created LoopOp will be inserted there.
+ OpBuilder builder(&mergeBlock->front());
+
+ // TODO(antiagainst): handle loop control properly
+ auto loopOp = builder.create<spirv::LoopOp>(location);
+ loopOp.addEntryAndMergeBlock();
+
+ return loopOp;
+}
+
+void ControlFlowStructurizer::collectBlocksInConstruct() {
+ assert(constructBlocks.empty() && "expected empty constructBlocks");
+
+ // Put the header block in the work list first.
+ constructBlocks.insert(headerBlock);
+
+ // For each item in the work list, add its successors excluding the merge
+ // block.
+ for (unsigned i = 0; i < constructBlocks.size(); ++i) {
+ for (auto *successor : constructBlocks[i]->getSuccessors())
+ if (successor != mergeBlock)
+ constructBlocks.insert(successor);
+ }
+}
+
+LogicalResult ControlFlowStructurizer::structurizeImpl() {
+ Operation *op = nullptr;
+ bool isLoop = continueBlock != nullptr;
+ if (isLoop) {
+ if (auto loopOp = createLoopOp())
+ op = loopOp.getOperation();
+ } else {
+ if (auto selectionOp = createSelectionOp())
+ op = selectionOp.getOperation();
+ }
+ if (!op)
+ return failure();
+ Region &body = op->getRegion(0);
+
+ BlockAndValueMapping mapper;
+ // All references to the old merge block should be directed to the
+ // selection/loop merge block in the SelectionOp/LoopOp's region.
+ mapper.map(mergeBlock, &body.back());
+
+ collectBlocksInConstruct();
+
+ // We've identified all blocks belonging to the selection/loop's region. Now
+ // need to "move" them into the selection/loop. Instead of really moving the
+ // blocks, in the following we copy them and remap all values and branches.
+ // This is because:
+ // * Inserting a block into a region requires the block not in any region
+ // before. But selections/loops can nest so we can create selection/loop ops
+ // in a nested manner, which means some blocks may already be in a
+ // selection/loop region when to be moved again.
+ // * It's much trickier to fix up the branches into and out of the loop's
+ // region: we need to treat not-moved blocks and moved blocks differently:
+ // Not-moved blocks jumping to the loop header block need to jump to the
+ // merge point containing the new loop op but not the loop continue block's
+ // back edge. Moved blocks jumping out of the loop need to jump to the
+ // merge block inside the loop region but not other not-moved blocks.
+ // We cannot use replaceAllUsesWith clearly and it's harder to follow the
+ // logic.
+
+ // Create a corresponding block in the SelectionOp/LoopOp's region for each
+ // block in this loop construct.
+ OpBuilder builder(body);
+ for (auto *block : constructBlocks) {
+ // Create a block and insert it before the selection/loop merge block in the
+ // SelectionOp/LoopOp's region.
+ auto *newBlock = builder.createBlock(&body.back());
+ mapper.map(block, newBlock);
+ LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
+ << " from block " << block << "\n");
+ if (!isFnEntryBlock(block)) {
+ for (BlockArgument blockArg : block->getArguments()) {
+ auto newArg = newBlock->addArgument(blockArg->getType());
+ mapper.map(blockArg, newArg);
+ LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
+ << " to " << newArg << '\n');
+ }
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << "[cf] block " << block << " is a function entry block\n");
+ }
+
+ for (auto &op : *block)
+ newBlock->push_back(op.clone(mapper));
+ }
+
+ // Go through all ops and remap the operands.
+ auto remapOperands = [&](Operation *op) {
+ for (auto &operand : op->getOpOperands())
+ if (auto mappedOp = mapper.lookupOrNull(operand.get()))
+ operand.set(mappedOp);
+ for (auto &succOp : op->getBlockOperands())
+ if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
+ succOp.set(mappedOp);
+ };
+ for (auto &block : body) {
+ block.walk(remapOperands);
+ }
+
+ // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
+ // the selection/loop construct into its region. Next we need to fix the
+ // connections between this new SelectionOp/LoopOp with existing blocks.
+
+ // All existing incoming branches should go to the merge block, where the
+ // SelectionOp/LoopOp resides right now.
+ headerBlock->replaceAllUsesWith(mergeBlock);
+
+ if (isLoop) {
+ // The loop selection/loop header block may have block arguments. Since now
+ // we place the selection/loop op inside the old merge block, we need to
+ // make sure the old merge block has the same block argument list.
+ assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
+ for (BlockArgument blockArg : headerBlock->getArguments()) {
+ mergeBlock->addArgument(blockArg->getType());
+ }
+
+ // If the loop header block has block arguments, make sure the spv.branch op
+ // matches.
+ SmallVector<Value, 4> blockArgs;
+ if (!headerBlock->args_empty())
+ blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
+
+ // The loop entry block should have a unconditional branch jumping to the
+ // loop header block.
+ builder.setInsertionPointToEnd(&body.front());
+ builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
+ ArrayRef<Value>(blockArgs));
+ }
+
+ // All the blocks cloned into the SelectionOp/LoopOp's region can now be
+ // cleaned up.
+ LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n");
+ // First we need to drop all operands' references inside all blocks. This is
+ // needed because we can have blocks referencing SSA values from one another.
+ for (auto *block : constructBlocks)
+ block->dropAllReferences();
+
+ // Then erase all old blocks.
+ for (auto *block : constructBlocks) {
+ // We've cloned all blocks belonging to this construct into the structured
+ // control flow op's region. Among these blocks, some may compose another
+ // selection/loop. If so, they will be recorded within blockMergeInfo.
+ // We need to update the pointers there to the newly remapped ones so we can
+ // continue structurizing them later.
+ // TODO(antiagainst): The asserts in the following assumes input SPIR-V blob
+ // forms correctly nested selection/loop constructs. We should relax this
+ // and support error cases better.
+ auto it = blockMergeInfo.find(block);
+ if (it != blockMergeInfo.end()) {
+ Block *newHeader = mapper.lookupOrNull(block);
+ assert(newHeader && "nested loop header block should be remapped!");
+
+ Block *newContinue = it->second.continueBlock;
+ if (newContinue) {
+ newContinue = mapper.lookupOrNull(newContinue);
+ assert(newContinue && "nested loop continue block should be remapped!");
+ }
+
+ Block *newMerge = it->second.mergeBlock;
+ if (Block *mappedTo = mapper.lookupOrNull(newMerge))
+ newMerge = mappedTo;
+
+ // The iterator should be erased before adding a new entry into
+ // blockMergeInfo to avoid iterator invalidation.
+ blockMergeInfo.erase(it);
+ blockMergeInfo.try_emplace(newHeader, newMerge, newContinue);
+ }
+
+ // The structured selection/loop's entry block does not have arguments.
+ // If the function's header block is also part of the structured control
+ // flow, we cannot just simply erase it because it may contain arguments
+ // matching the function signature and used by the cloned blocks.
+ if (isFnEntryBlock(block)) {
+ LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
+ << " to only contain a spv.Branch op\n");
+ // Still keep the function entry block for the potential block arguments,
+ // but replace all ops inside with a branch to the merge block.
+ block->clear();
+ builder.setInsertionPointToEnd(block);
+ builder.create<spirv::BranchOp>(location, mergeBlock);
+ } else {
+ LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n");
+ block->erase();
+ }
+ }
+
+ LLVM_DEBUG(
+ llvm::dbgs() << "[cf] after structurizing construct with header block "
+ << headerBlock << ":\n"
+ << *op << '\n');
+
+ return success();
+}
+
+LogicalResult Deserializer::wireUpBlockArgument() {
+ LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");
+
+ OpBuilder::InsertionGuard guard(opBuilder);
+
+ for (const auto &info : blockPhiInfo) {
+ Block *block = info.first;
+ const BlockPhiInfo &phiInfo = info.second;
+ LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
+ LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << '\n');
+
+ // Set insertion point to before this block's terminator early because we
+ // may materialize ops via getValue() call.
+ auto *op = block->getTerminator();
+ opBuilder.setInsertionPoint(op);
+
+ SmallVector<Value, 4> blockArgs;
+ blockArgs.reserve(phiInfo.size());
+ for (uint32_t valueId : phiInfo) {
+ if (Value value = getValue(valueId)) {
+ blockArgs.push_back(value);
+ LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
+ << " id = " << valueId << '\n');
+ } else {
+ return emitError(unknownLoc, "OpPhi references undefined value!");
+ }
+ }
+
+ if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
+ // Replace the previous branch op with a new one with block arguments.
+ opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
+ blockArgs);
+ branchOp.erase();
+ } else {
+ return emitError(unknownLoc, "unimplemented terminator for Phi creation");
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n");
+ LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << '\n');
+ }
+ blockPhiInfo.clear();
+
+ LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n");
+ return success();
+}
+
+LogicalResult Deserializer::structurizeControlFlow() {
+ LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
+
+ while (!blockMergeInfo.empty()) {
+ Block *headerBlock = blockMergeInfo.begin()->first;
+ BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
+
+ LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n");
+ LLVM_DEBUG(headerBlock->print(llvm::dbgs()));
+
+ auto *mergeBlock = mergeInfo.mergeBlock;
+ assert(mergeBlock && "merge block cannot be nullptr");
+ if (!mergeBlock->args_empty())
+ return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
+ LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n");
+ LLVM_DEBUG(mergeBlock->print(llvm::dbgs()));
+
+ auto *continueBlock = mergeInfo.continueBlock;
+ if (continueBlock) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "[cf] continue block " << continueBlock << ":\n");
+ LLVM_DEBUG(continueBlock->print(llvm::dbgs()));
+ }
+
+ // Erase this case before calling into structurizer, who will update
+ // blockMergeInfo.
+ blockMergeInfo.erase(blockMergeInfo.begin());
+ if (failed(ControlFlowStructurizer::structurize(unknownLoc, blockMergeInfo,
+ headerBlock, mergeBlock,
+ continueBlock)))
+ return failure();
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Instruction
+//===----------------------------------------------------------------------===//
+
+Value Deserializer::getValue(uint32_t id) {
+ if (auto constInfo = getConstant(id)) {
+ // Materialize a `spv.constant` op at every use site.
+ return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
+ constInfo->first);
+ }
+ if (auto varOp = getGlobalVariable(id)) {
+ auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
+ unknownLoc, varOp.type(),
+ opBuilder.getSymbolRefAttr(varOp.getOperation()));
+ return addressOfOp.pointer();
+ }
+ if (auto constOp = getSpecConstant(id)) {
+ auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+ unknownLoc, constOp.default_value().getType(),
+ opBuilder.getSymbolRefAttr(constOp.getOperation()));
+ return referenceOfOp.reference();
+ }
+ if (auto undef = getUndefType(id)) {
+ return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
+ }
+ return valueMap.lookup(id);
+}
+
+LogicalResult
+Deserializer::sliceInstruction(spirv::Opcode &opcode,
+ ArrayRef<uint32_t> &operands,
+ Optional<spirv::Opcode> expectedOpcode) {
+ auto binarySize = binary.size();
+ if (curOffset >= binarySize) {
+ return emitError(unknownLoc, "expected ")
+ << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
+ : "more")
+ << " instruction";
+ }
+
+ // For each instruction, get its word count from the first word to slice it
+ // from the stream properly, and then dispatch to the instruction handler.
+
+ uint32_t wordCount = binary[curOffset] >> 16;
+
+ if (wordCount == 0)
+ return emitError(unknownLoc, "word count cannot be zero");
+
+ uint32_t nextOffset = curOffset + wordCount;
+ if (nextOffset > binarySize)
+ return emitError(unknownLoc, "insufficient words for the last instruction");
+
+ opcode = extractOpcode(binary[curOffset]);
+ operands = binary.slice(curOffset + 1, wordCount - 1);
+ curOffset = nextOffset;
+ return success();
+}
+
+LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands,
+ bool deferInstructions) {
+ LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
+ << spirv::stringifyOpcode(opcode) << "\n");
+
+ // First dispatch all the instructions whose opcode does not correspond to
+ // those that have a direct mirror in the SPIR-V dialect
+ switch (opcode) {
+ case spirv::Opcode::OpBitcast:
+ return processBitcast(operands);
+ case spirv::Opcode::OpCapability:
+ return processCapability(operands);
+ case spirv::Opcode::OpExtension:
+ return processExtension(operands);
+ case spirv::Opcode::OpExtInst:
+ return processExtInst(operands);
+ case spirv::Opcode::OpExtInstImport:
+ return processExtInstImport(operands);
+ case spirv::Opcode::OpMemberName:
+ return processMemberName(operands);
+ case spirv::Opcode::OpMemoryModel:
+ return processMemoryModel(operands);
+ case spirv::Opcode::OpEntryPoint:
+ case spirv::Opcode::OpExecutionMode:
+ if (deferInstructions) {
+ deferredInstructions.emplace_back(opcode, operands);
+ return success();
+ }
+ break;
+ case spirv::Opcode::OpVariable:
+ if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
+ return processGlobalVariable(operands);
+ }
+ break;
+ case spirv::Opcode::OpName:
+ return processName(operands);
+ case spirv::Opcode::OpModuleProcessed:
+ case spirv::Opcode::OpString:
+ case spirv::Opcode::OpSource:
+ case spirv::Opcode::OpSourceContinued:
+ case spirv::Opcode::OpSourceExtension:
+ // TODO: This is debug information embedded in the binary which should be
+ // translated into the spv.module.
+ return success();
+ case spirv::Opcode::OpTypeVoid:
+ case spirv::Opcode::OpTypeBool:
+ case spirv::Opcode::OpTypeInt:
+ case spirv::Opcode::OpTypeFloat:
+ case spirv::Opcode::OpTypeVector:
+ case spirv::Opcode::OpTypeArray:
+ case spirv::Opcode::OpTypeFunction:
+ case spirv::Opcode::OpTypeRuntimeArray:
+ case spirv::Opcode::OpTypeStruct:
+ case spirv::Opcode::OpTypePointer:
+ return processType(opcode, operands);
+ case spirv::Opcode::OpConstant:
+ return processConstant(operands, /*isSpec=*/false);
+ case spirv::Opcode::OpSpecConstant:
+ return processConstant(operands, /*isSpec=*/true);
+ case spirv::Opcode::OpConstantComposite:
+ return processConstantComposite(operands);
+ case spirv::Opcode::OpConstantTrue:
+ return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
+ case spirv::Opcode::OpSpecConstantTrue:
+ return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
+ case spirv::Opcode::OpConstantFalse:
+ return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
+ case spirv::Opcode::OpSpecConstantFalse:
+ return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
+ case spirv::Opcode::OpConstantNull:
+ return processConstantNull(operands);
+ case spirv::Opcode::OpDecorate:
+ return processDecoration(operands);
+ case spirv::Opcode::OpMemberDecorate:
+ return processMemberDecoration(operands);
+ case spirv::Opcode::OpFunction:
+ return processFunction(operands);
+ case spirv::Opcode::OpLabel:
+ return processLabel(operands);
+ case spirv::Opcode::OpBranch:
+ return processBranch(operands);
+ case spirv::Opcode::OpBranchConditional:
+ return processBranchConditional(operands);
+ case spirv::Opcode::OpSelectionMerge:
+ return processSelectionMerge(operands);
+ case spirv::Opcode::OpLoopMerge:
+ return processLoopMerge(operands);
+ case spirv::Opcode::OpPhi:
+ return processPhi(operands);
+ case spirv::Opcode::OpUndef:
+ return processUndef(operands);
+ default:
+ break;
+ }
+ return dispatchToAutogenDeserialization(opcode, operands);
+}
+
+LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc, "OpUndef instruction must have two operands");
+ }
+ auto type = getType(operands[0]);
+ if (!type) {
+ return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
+ }
+ undefMap[operands[1]] = type;
+ return success();
+}
+
+// TODO(b/130356985): This method is copied from the auto-generated
+// deserialization function for OpBitcast instruction. This is to avoid
+// generating a Bitcast operations for cast from signed integer to unsigned
+// integer and viceversa. MLIR doesn't have native support for this so they both
+// end up mapping to the same type right now which is illegal according to
+// OpBitcast semantics (and enforced by the SPIR-V dialect).
+LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) {
+ SmallVector<Type, 1> resultTypes;
+ size_t wordIndex = 0;
+ (void)wordIndex;
+ uint32_t valueID = 0;
+ (void)valueID;
+ {
+ if (wordIndex >= words.size()) {
+ return emitError(
+ unknownLoc,
+ "expected result type <id> while deserializing spirv::BitcastOp");
+ }
+ auto ty = getType(words[wordIndex]);
+ if (!ty) {
+ return emitError(unknownLoc, "unknown type result <id> : ")
+ << words[wordIndex];
+ }
+ resultTypes.push_back(ty);
+ wordIndex++;
+ if (wordIndex >= words.size()) {
+ return emitError(
+ unknownLoc,
+ "expected result <id> while deserializing spirv::BitcastOp");
+ }
+ }
+ valueID = words[wordIndex++];
+ SmallVector<Value, 4> operands;
+ SmallVector<NamedAttribute, 4> attributes;
+ if (wordIndex < words.size()) {
+ auto arg = getValue(words[wordIndex]);
+ if (!arg) {
+ return emitError(unknownLoc, "unknown result <id> : ")
+ << words[wordIndex];
+ }
+ operands.push_back(arg);
+ wordIndex++;
+ }
+ if (wordIndex != words.size()) {
+ return emitError(unknownLoc,
+ "found more operands than expected when deserializing "
+ "spirv::BitcastOp, only ")
+ << wordIndex << " of " << words.size() << " processed";
+ }
+ if (resultTypes[0] == operands[0]->getType() &&
+ resultTypes[0].isa<IntegerType>()) {
+ // TODO(b/130356985): This check is added to ignore error in Op verification
+ // due to both signed and unsigned integers mapping to the same
+ // type. Without this check this method is same as what is auto-generated.
+ valueMap[valueID] = operands[0];
+ return success();
+ }
+
+ auto op = opBuilder.create<spirv::BitcastOp>(unknownLoc, resultTypes,
+ operands, attributes);
+ (void)op;
+ valueMap[valueID] = op.getResult();
+
+ if (decorations.count(valueID)) {
+ auto attrs = decorations[valueID].getAttrs();
+ attributes.append(attrs.begin(), attrs.end());
+ }
+ return success();
+}
+
+LogicalResult Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 4) {
+ return emitError(unknownLoc,
+ "OpExtInst must have at least 4 operands, result type "
+ "<id>, result <id>, set <id> and instruction opcode");
+ }
+ if (!extendedInstSets.count(operands[2])) {
+ return emitError(unknownLoc, "undefined set <id> in OpExtInst");
+ }
+ SmallVector<uint32_t, 4> slicedOperands;
+ slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
+ slicedOperands.append(std::next(operands.begin(), 4), operands.end());
+ return dispatchToExtensionSetAutogenDeserialization(
+ extendedInstSets[operands[2]], operands[3], slicedOperands);
+}
+
+namespace {
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
+ unsigned wordIndex = 0;
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc,
+ "missing Execution Model specification in OpEntryPoint");
+ }
+ auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc, "missing <id> in OpEntryPoint");
+ }
+ // Get the function <id>
+ auto fnID = words[wordIndex++];
+ // Get the function name
+ auto fnName = decodeStringLiteral(words, wordIndex);
+ // Verify that the function <id> matches the fnName
+ auto parsedFunc = getFunction(fnID);
+ if (!parsedFunc) {
+ return emitError(unknownLoc, "no function matching <id> ") << fnID;
+ }
+ if (parsedFunc.getName() != fnName) {
+ return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
+ "and OpFunction with <id> ")
+ << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
+ }
+ SmallVector<Attribute, 4> interface;
+ while (wordIndex < words.size()) {
+ auto arg = getGlobalVariable(words[wordIndex]);
+ if (!arg) {
+ return emitError(unknownLoc, "undefined result <id> ")
+ << words[wordIndex] << " while decoding OpEntryPoint";
+ }
+ interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
+ wordIndex++;
+ }
+ opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+ opBuilder.getSymbolRefAttr(fnName),
+ opBuilder.getArrayAttr(interface));
+ return success();
+}
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
+ unsigned wordIndex = 0;
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc,
+ "missing function result <id> in OpExecutionMode");
+ }
+ // Get the function <id> to get the name of the function
+ auto fnID = words[wordIndex++];
+ auto fn = getFunction(fnID);
+ if (!fn) {
+ return emitError(unknownLoc, "no function matching <id> ") << fnID;
+ }
+ // Get the Execution mode
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
+ }
+ auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+
+ // Get the values
+ SmallVector<Attribute, 4> attrListElems;
+ while (wordIndex < words.size()) {
+ attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
+ }
+ auto values = opBuilder.getArrayAttr(attrListElems);
+ opBuilder.create<spirv::ExecutionModeOp>(
+ unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
+ return success();
+}
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 3) {
+ return emitError(
+ unknownLoc,
+ "OpControlBarrier must have execution scope <id>, memory scope <id> "
+ "and memory semantics <id>");
+ }
+
+ SmallVector<IntegerAttr, 3> argAttrs;
+ for (auto operand : operands) {
+ auto argAttr = getConstantInt(operand);
+ if (!argAttr) {
+ return emitError(unknownLoc,
+ "expected 32-bit integer constant from <id> ")
+ << operand << " for OpControlBarrier";
+ }
+ argAttrs.push_back(argAttr);
+ }
+
+ opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0],
+ argAttrs[1], argAttrs[2]);
+ return success();
+}
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 3) {
+ return emitError(unknownLoc,
+ "OpFunctionCall must have at least 3 operands");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto resultID = operands[1];
+ auto functionID = operands[2];
+
+ auto functionName = getFunctionSymbol(functionID);
+
+ SmallVector<Value, 4> arguments;
+ for (auto operand : llvm::drop_begin(operands, 3)) {
+ auto value = getValue(operand);
+ if (!value) {
+ return emitError(unknownLoc, "unknown <id> ")
+ << operand << " used by OpFunctionCall";
+ }
+ arguments.push_back(value);
+ }
+
+ SmallVector<Type, 1> resultTypes;
+ if (!isVoidType(resultType)) {
+ resultTypes.push_back(resultType);
+ }
+
+ auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
+ unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName),
+ arguments);
+
+ if (!resultTypes.empty()) {
+ valueMap[resultID] = opFunctionCall.getResult(0);
+ }
+ return success();
+}
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
+ "and memory semantics <id>");
+ }
+
+ SmallVector<IntegerAttr, 2> argAttrs;
+ for (auto operand : operands) {
+ auto argAttr = getConstantInt(operand);
+ if (!argAttr) {
+ return emitError(unknownLoc,
+ "expected 32-bit integer constant from <id> ")
+ << operand << " for OpMemoryBarrier";
+ }
+ argAttrs.push_back(argAttr);
+ }
+
+ opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0],
+ argAttrs[1]);
+ return success();
+}
+
+// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
+// various Deserializer::processOp<...>() specializations.
+#define GET_DESERIALIZATION_FNS
+#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
+} // namespace
+
+Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
+ MLIRContext *context) {
+ Deserializer deserializer(binary, context);
+
+ if (failed(deserializer.deserialize()))
+ return llvm::None;
+
+ return deserializer.collect();
+}
OpenPOWER on IntegriCloud