diff options
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | 110 |
1 files changed, 104 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 1aad7173dc6..a3d71eda5d9 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -90,9 +90,20 @@ private: /// them to their handler method accordingly. LogicalResult processFunction(ArrayRef<uint32_t> operands); + /// Process the OpVariable instructions at current `offset` into `binary`. It + /// is expected that this method is used for variables that are to be defined + /// at module scope and will be deserialized into a spv.globalVariable + /// instruction. + LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands); + /// Get the FuncOp associated with a result <id> of OpFunction. FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } + /// Get the global variable associated with a result <id> of OpVariable + spirv::GlobalVariableOp getVariable(uint32_t id) { + return globalVariableMap.lookup(id); + } + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// @@ -138,7 +149,15 @@ private: //===--------------------------------------------------------------------===// /// Get the Value associated with a result <id>. - Value *getValue(uint32_t id) { return valueMap.lookup(id); } + Value *getValue(uint32_t id) { + if (auto varOp = getVariable(id)) { + auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( + unknownLoc, varOp.type(), + opBuilder.getSymbolRefAttr(varOp.getOperation())); + return addressOfOp.pointer(); + } + return valueMap.lookup(id); + } /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if @@ -198,6 +217,9 @@ private: // Result <id> to function mapping. DenseMap<uint32_t, FuncOp> funcMap; + // Result <id> to variable mapping; + DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; + // Result <id> to value mapping. DenseMap<uint32_t, Value *> valueMap; @@ -452,6 +474,76 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { return success(); } +LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { + unsigned wordIndex = 0; + if (operands.size() < 3) { + return emitError( + unknownLoc, + "OpVariable needs at least 3 operands, type, <id> and storage class"); + } + + // Result Type. + auto type = getType(operands[wordIndex]); + if (!type) { + return emitError(unknownLoc, "unknown result type <id> : ") + << operands[wordIndex]; + } + auto ptrType = type.dyn_cast<spirv::PointerType>(); + if (!ptrType) { + return emitError(unknownLoc, + "expected a result type <id> to be a spv.ptr, found : ") + << type; + } + wordIndex++; + + // Result <id>. + auto variableID = operands[wordIndex]; + auto variableName = nameMap.lookup(variableID).str(); + if (variableName.empty()) { + variableName = "spirv_var_" + std::to_string(variableID); + } + wordIndex++; + + // Storage class. + auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); + if (ptrType.getStorageClass() != storageClass) { + return emitError(unknownLoc, "mismatch in storage class of pointer type ") + << type << " and that specified in OpVariable instruction : " + << stringifyStorageClass(storageClass); + } + wordIndex++; + + // Initializer. + SymbolRefAttr initializer = nullptr; + if (wordIndex < operands.size()) { + auto initializerOp = getVariable(operands[wordIndex]); + if (!initializerOp) { + return emitError(unknownLoc, "unknown <id> ") + << operands[wordIndex] << "used as initializer"; + } + wordIndex++; + initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); + } + if (wordIndex != operands.size()) { + return emitError(unknownLoc, + "found more operands than expected when deserializing " + "OpVariable instruction, only ") + << wordIndex << " of " << operands.size() << " processed"; + } + auto varOp = opBuilder.create<spirv::GlobalVariableOp>( + unknownLoc, opBuilder.getTypeAttr(type), + opBuilder.getStringAttr(variableName), initializer); + + // Decorations. + if (decorations.count(variableID)) { + for (auto attr : decorations[variableID].getAttrs()) { + varOp.setAttr(attr.first, attr.second); + } + } + globalVariableMap[variableID] = varOp; + return success(); +} + LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpName needs at least 2 operands"); @@ -887,6 +979,11 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return success(); } break; + case spirv::Opcode::OpVariable: + if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { + return processGlobalVariable(operands); + } + break; case spirv::Opcode::OpName: return processName(operands); case spirv::Opcode::OpTypeVoid: @@ -954,18 +1051,19 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { "and OpFunction with <id> ") << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); } - SmallVector<Value *, 4> interface; + SmallVector<Attribute, 4> interface; while (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); + auto arg = getVariable(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "undefined result <id> ") << words[wordIndex] << " while decoding OpEntryPoint"; } - interface.push_back(arg); + interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } - opBuilder.create<spirv::EntryPointOp>( - unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface); + opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model, + opBuilder.getSymbolRefAttr(fnName), + opBuilder.getArrayAttr(interface)); return success(); } |