summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp110
1 files changed, 104 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 1aad7173dc6..a3d71eda5d9 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -90,9 +90,20 @@ private:
/// them to their handler method accordingly.
LogicalResult processFunction(ArrayRef<uint32_t> operands);
+ /// Process the OpVariable instructions at current `offset` into `binary`. It
+ /// is expected that this method is used for variables that are to be defined
+ /// at module scope and will be deserialized into a spv.globalVariable
+ /// instruction.
+ LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
+
/// Get the FuncOp associated with a result <id> of OpFunction.
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
+ /// Get the global variable associated with a result <id> of OpVariable
+ spirv::GlobalVariableOp getVariable(uint32_t id) {
+ return globalVariableMap.lookup(id);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@@ -138,7 +149,15 @@ private:
//===--------------------------------------------------------------------===//
/// Get the Value associated with a result <id>.
- Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+ Value *getValue(uint32_t id) {
+ if (auto varOp = getVariable(id)) {
+ auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
+ unknownLoc, varOp.type(),
+ opBuilder.getSymbolRefAttr(varOp.getOperation()));
+ return addressOfOp.pointer();
+ }
+ return valueMap.lookup(id);
+ }
/// Slices the first instruction out of `binary` and returns its opcode and
/// operands via `opcode` and `operands` respectively. Returns failure if
@@ -198,6 +217,9 @@ private:
// Result <id> to function mapping.
DenseMap<uint32_t, FuncOp> funcMap;
+ // Result <id> to variable mapping;
+ DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+
// Result <id> to value mapping.
DenseMap<uint32_t, Value *> valueMap;
@@ -452,6 +474,76 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
+ unsigned wordIndex = 0;
+ if (operands.size() < 3) {
+ return emitError(
+ unknownLoc,
+ "OpVariable needs at least 3 operands, type, <id> and storage class");
+ }
+
+ // Result Type.
+ auto type = getType(operands[wordIndex]);
+ if (!type) {
+ return emitError(unknownLoc, "unknown result type <id> : ")
+ << operands[wordIndex];
+ }
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return emitError(unknownLoc,
+ "expected a result type <id> to be a spv.ptr, found : ")
+ << type;
+ }
+ wordIndex++;
+
+ // Result <id>.
+ auto variableID = operands[wordIndex];
+ auto variableName = nameMap.lookup(variableID).str();
+ if (variableName.empty()) {
+ variableName = "spirv_var_" + std::to_string(variableID);
+ }
+ wordIndex++;
+
+ // Storage class.
+ auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
+ if (ptrType.getStorageClass() != storageClass) {
+ return emitError(unknownLoc, "mismatch in storage class of pointer type ")
+ << type << " and that specified in OpVariable instruction : "
+ << stringifyStorageClass(storageClass);
+ }
+ wordIndex++;
+
+ // Initializer.
+ SymbolRefAttr initializer = nullptr;
+ if (wordIndex < operands.size()) {
+ auto initializerOp = getVariable(operands[wordIndex]);
+ if (!initializerOp) {
+ return emitError(unknownLoc, "unknown <id> ")
+ << operands[wordIndex] << "used as initializer";
+ }
+ wordIndex++;
+ initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
+ }
+ if (wordIndex != operands.size()) {
+ return emitError(unknownLoc,
+ "found more operands than expected when deserializing "
+ "OpVariable instruction, only ")
+ << wordIndex << " of " << operands.size() << " processed";
+ }
+ auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
+ unknownLoc, opBuilder.getTypeAttr(type),
+ opBuilder.getStringAttr(variableName), initializer);
+
+ // Decorations.
+ if (decorations.count(variableID)) {
+ for (auto attr : decorations[variableID].getAttrs()) {
+ varOp.setAttr(attr.first, attr.second);
+ }
+ }
+ globalVariableMap[variableID] = varOp;
+ return success();
+}
+
LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc, "OpName needs at least 2 operands");
@@ -887,6 +979,11 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return success();
}
break;
+ case spirv::Opcode::OpVariable:
+ if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
+ return processGlobalVariable(operands);
+ }
+ break;
case spirv::Opcode::OpName:
return processName(operands);
case spirv::Opcode::OpTypeVoid:
@@ -954,18 +1051,19 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
"and OpFunction with <id> ")
<< fnID << ": " << fnName << " vs. " << parsedFunc.getName();
}
- SmallVector<Value *, 4> interface;
+ SmallVector<Attribute, 4> interface;
while (wordIndex < words.size()) {
- auto arg = getValue(words[wordIndex]);
+ auto arg = getVariable(words[wordIndex]);
if (!arg) {
return emitError(unknownLoc, "undefined result <id> ")
<< words[wordIndex] << " while decoding OpEntryPoint";
}
- interface.push_back(arg);
+ interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
wordIndex++;
}
- opBuilder.create<spirv::EntryPointOp>(
- unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface);
+ opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+ opBuilder.getSymbolRefAttr(fnName),
+ opBuilder.getArrayAttr(interface));
return success();
}
OpenPOWER on IntegriCloud