summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp23
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp347
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp110
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp120
4 files changed, 485 insertions, 115 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 53a40dfa365..035de4f815d 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -136,26 +136,26 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
signatureConverter, newFuncOp))) {
return failure();
}
- // Create spv.Variable ops for each of the arguments. These need to be bound
- // by the runtime. For now use descriptor_set 0, and arg number as the binding
- // number.
+ // Create spv.globalVariable ops for each of the arguments. These need to be
+ // bound by the runtime. For now use descriptor_set 0, and arg number as the
+ // binding number.
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
if (!module) {
return funcOp.emitError("expected op to be within a spv.module");
}
OpBuilder builder(module.getOperation()->getRegion(0));
- SmallVector<Value *, 4> interface;
+ SmallVector<Attribute, 4> interface;
for (auto &convertedArgType :
llvm::enumerate(signatureConverter.getConvertedTypes())) {
- auto variableOp = builder.create<spirv::VariableOp>(
- funcOp.getLoc(), convertedArgType.value(),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
- llvm::None);
+ std::string varName = funcOp.getName().str() + "_arg_" +
+ std::to_string(convertedArgType.index());
+ auto variableOp = builder.create<spirv::GlobalVariableOp>(
+ funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
+ builder.getStringAttr(varName), nullptr);
variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
variableOp.setAttr("binding",
builder.getI32IntegerAttr(convertedArgType.index()));
- interface.push_back(variableOp.getResult());
+ interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
@@ -164,7 +164,8 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
funcOp.getLoc(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
- builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+ builder.getSymbolRefAttr(newFuncOp.getName()),
+ builder.getArrayAttr(interface));
return success();
}
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 4bea441c366..9947c0254a9 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -32,11 +32,15 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
+static constexpr const char kInitializerAttrName[] = "initializer";
+static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kIsSpecConstName[] = "is_spec_const";
+static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
-static constexpr const char kFnNameAttrName[] = "fn";
+static constexpr const char kVariableAttrName[] = "variable";
//===----------------------------------------------------------------------===//
// Common utility functions
@@ -239,6 +243,71 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
printer->printOptionalAttrDict(op->getAttrs());
}
+static ParseResult parseVariableDecorations(OpAsmParser *parser,
+ OperationState *state) {
+ auto builtInName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+ if (succeeded(parser->parseOptionalKeyword("bind"))) {
+ Attribute set, binding;
+ // Parse optional descriptor binding
+ auto descriptorSetName = convertToSnakeCase(
+ stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+ Type i32Type = parser->getBuilder().getIntegerType(32);
+ if (parser->parseLParen() ||
+ parser->parseAttribute(set, i32Type, descriptorSetName,
+ state->attributes) ||
+ parser->parseComma() ||
+ parser->parseAttribute(binding, i32Type, bindingName,
+ state->attributes) ||
+ parser->parseRParen()) {
+ return failure();
+ }
+ } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
+ StringAttr builtIn;
+ if (parser->parseLParen() ||
+ parser->parseAttribute(builtIn, Type(), builtInName,
+ state->attributes) ||
+ parser->parseRParen()) {
+ return failure();
+ }
+ }
+
+ // Parse other attributes
+ if (parser->parseOptionalAttributeDict(state->attributes))
+ return failure();
+
+ return success();
+}
+
+static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
+ SmallVectorImpl<StringRef> &elidedAttrs) {
+ // Print optional descriptor binding
+ auto descriptorSetName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+ auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
+ auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
+ if (descriptorSet && binding) {
+ elidedAttrs.push_back(descriptorSetName);
+ elidedAttrs.push_back(bindingName);
+ *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+ << ")";
+ }
+
+ // Print BuiltIn attribute if present
+ auto builtInName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+ if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
+ *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
+ elidedAttrs.push_back(builtInName);
+ }
+
+ printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+}
+
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
@@ -363,6 +432,53 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
}
//===----------------------------------------------------------------------===//
+// spv._address_of
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAddressOfOp(OpAsmParser *parser,
+ OperationState *state) {
+ SymbolRefAttr varRefAttr;
+ Type type;
+ if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName,
+ state->attributes) ||
+ parser->parseColonType(type)) {
+ return failure();
+ }
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return parser->emitError(parser->getCurrentLocation(),
+ "expected spv.ptr type");
+ }
+ state->addTypes(ptrType);
+ return success();
+}
+
+static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
+ SmallVector<StringRef, 4> elidedAttrs;
+ *printer << spirv::AddressOfOp::getOperationName();
+
+ // Print symbol name.
+ *printer << " @" << addressOfOp.variable();
+
+ // Print the type.
+ *printer << " : " << addressOfOp.pointer();
+}
+
+static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
+ auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
+ auto varOp =
+ moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
+ if (!varOp) {
+ return addressOfOp.emitError("expected spv.globalVariable symbol");
+ }
+ if (addressOfOp.pointer()->getType() != varOp.type()) {
+ return addressOfOp.emitError(
+ "mismatch in result type and type of global variable referenced");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
@@ -541,18 +657,28 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser,
SmallVector<OpAsmParser::OperandType, 0> identifiers;
SmallVector<Type, 0> idTypes;
- Attribute fn;
- auto loc = parser->getCurrentLocation();
-
+ SymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
- parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
- parser->parseTrailingOperandList(identifiers) ||
- parser->parseOptionalColonTypeList(idTypes) ||
- parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
+ parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) {
return failure();
}
- if (!fn.isa<SymbolRefAttr>()) {
- return parser->emitError(loc, "expected symbol reference attribute");
+
+ if (!parser->parseOptionalComma()) {
+ // Parse the interface variables
+ SmallVector<Attribute, 4> interfaceVars;
+ do {
+ // The name of the interface variable attribute isnt important
+ auto attrName = "var_symbol";
+ SymbolRefAttr var;
+ SmallVector<NamedAttribute, 1> attrs;
+ if (parser->parseAttribute(var, Type(), attrName, attrs)) {
+ return failure();
+ }
+ interfaceVars.push_back(var);
+ } while (!parser->parseOptionalComma());
+ state->attributes.push_back(
+ {parser->getBuilder().getIdentifier(kInterfaceAttrName),
+ parser->getBuilder().getArrayAttr(interfaceVars)});
}
return success();
}
@@ -561,27 +687,16 @@ static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) {
*printer << spirv::EntryPointOp::getOperationName() << " \""
<< stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
<< entryPointOp.fn();
- if (!entryPointOp.getNumOperands()) {
- return;
+ if (auto interface = entryPointOp.interface()) {
+ *printer << ", ";
+ mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(),
+ [&](Attribute a) { printer->printAttribute(a); });
}
- *printer << ", ";
- mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
- [&](Value *a) { printer->printOperand(a); });
- *printer << " : ";
- mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
- [&](const Value *a) { *printer << a->getType(); });
}
static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
- // Verify that all the interface ops are created from VariableOp
- for (auto interface : entryPointOp.interface()) {
- if (!llvm::isa_and_nonnull<spirv::VariableOp>(interface->getDefiningOp())) {
- return entryPointOp.emitOpError("interface operands to entry point must "
- "be generated from a variable op");
- }
- // TODO: Before version 1.4 the variables can only have storage_class of
- // Input or Output. That needs to be verified.
- }
+ // Checks for fn and interface symbol reference are done in spirv::ModuleOp
+ // verification.
return success();
}
@@ -628,6 +743,95 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
}
//===----------------------------------------------------------------------===//
+// spv.globalVariable
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
+ OperationState *state) {
+ // Parse variable type.
+ TypeAttr typeAttr;
+ auto loc = parser->getCurrentLocation();
+ if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName,
+ state->attributes)) {
+ return failure();
+ }
+ auto ptrType = typeAttr.getValue().dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return parser->emitError(loc, "expected spv.ptr type");
+ }
+
+ // Parse variable name.
+ StringAttr nameAttr;
+ if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+ state->attributes)) {
+ return failure();
+ }
+
+ // Parse optional initializer
+ if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) {
+ SymbolRefAttr initSymbol;
+ if (parser->parseLParen() ||
+ parser->parseAttribute(initSymbol, Type(), kInitializerAttrName,
+ state->attributes) ||
+ parser->parseRParen())
+ return failure();
+ }
+
+ if (parseVariableDecorations(parser, state)) {
+ return failure();
+ }
+
+ return success();
+}
+
+static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
+ auto *op = varOp.getOperation();
+ SmallVector<StringRef, 4> elidedAttrs{
+ spirv::attributeName<spirv::StorageClass>()};
+ *printer << spirv::GlobalVariableOp::getOperationName();
+
+ // Print variable type.
+ *printer << " " << varOp.type();
+ elidedAttrs.push_back(kTypeAttrName);
+
+ // Print variable name.
+ *printer << " @" << varOp.sym_name();
+ elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+
+ // Print optional initializer
+ if (auto initializer = varOp.initializer()) {
+ *printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
+ << ")";
+ elidedAttrs.push_back(kInitializerAttrName);
+ }
+ printVariableDecorations(op, printer, elidedAttrs);
+}
+
+static LogicalResult verify(spirv::GlobalVariableOp varOp) {
+ // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
+ // object. It cannot be Generic. It must be the same as the Storage Class
+ // operand of the Result Type."
+ if (varOp.storageClass() == spirv::StorageClass::Generic)
+ return varOp.emitOpError("storage class cannot be 'Generic'");
+
+ if (auto initializer =
+ varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+ // Get the module
+ auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
+ // TODO: Currently only variable initialization with other variables is
+ // supported. They could be constants as well, but this needs module-level
+ // constants to have symbol name as well.
+ if (!moduleOp.lookupSymbol<spirv::GlobalVariableOp>(
+ initializer.getValue())) {
+ return varOp.emitOpError(
+ "initializer must be result of a spv.globalVariable op");
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//
@@ -773,13 +977,33 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
for (auto &op : body) {
if (op.getDialect() == dialect) {
// For EntryPoint op, check that the function and execution model is not
- // duplicated in EntryPointOps
+ // duplicated in EntryPointOps. Also verify that the interface specified
+ // comes from globalVariables here to make this check cheaper.
if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) {
auto funcOp = table.lookup<FuncOp>(entryPointOp.fn());
if (!funcOp) {
return entryPointOp.emitError("function '")
<< entryPointOp.fn() << "' not found in 'spv.module'";
}
+ if (auto interface = entryPointOp.interface()) {
+ for (auto varRef : interface.getValue().getValue()) {
+ auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+ if (!varSymRef) {
+ return entryPointOp.emitError(
+ "expected symbol reference for interface "
+ "specification instead of '")
+ << varRef;
+ }
+ auto variableOp =
+ table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
+ if (!variableOp) {
+ return entryPointOp.emitError("expected spv.globalVariable "
+ "symbol reference instead of'")
+ << varSymRef << "'";
+ }
+ }
+ }
+
auto key = std::pair<FuncOp, spirv::ExecutionModel>(
funcOp, entryPointOp.execution_model());
auto entryPtIt = entryPoints.find(key);
@@ -898,42 +1122,9 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
return failure();
}
- auto builtInName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
- if (succeeded(parser->parseOptionalKeyword("bind"))) {
- Attribute set, binding;
- // Parse optional descriptor binding
- auto descriptorSetName = convertToSnakeCase(
- stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
- Type i32Type = parser->getBuilder().getIntegerType(32);
- if (parser->parseLParen() ||
- parser->parseAttribute(set, i32Type, descriptorSetName,
- state->attributes) ||
- parser->parseComma() ||
- parser->parseAttribute(binding, i32Type, bindingName,
- state->attributes) ||
- parser->parseRParen()) {
- return failure();
- }
- } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
- Attribute builtIn;
- if (parser->parseLParen() ||
- parser->parseAttribute(builtIn, Type(), builtInName,
- state->attributes) ||
- parser->parseRParen()) {
- return failure();
- }
- if (!builtIn.isa<StringAttr>()) {
- return parser->emitError(parser->getCurrentLocation(),
- "expected string value for built_in attribute");
- }
- }
-
- // Parse other attributes
- if (parser->parseOptionalAttributeDict(state->attributes))
+ if (parseVariableDecorations(parser, state)) {
return failure();
+ }
// Parse result pointer type
Type type;
@@ -976,29 +1167,8 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
*printer << ")";
}
- // Print optional descriptor binding
- auto descriptorSetName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
- auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
- auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
- if (descriptorSet && binding) {
- elidedAttrs.push_back(descriptorSetName);
- elidedAttrs.push_back(bindingName);
- *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
- << ")";
- }
-
- // Print BuiltIn attribute if present
- auto builtInName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
- if (auto builtin = varOp.getAttrOfType<StringAttr>(builtInName)) {
- *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
- elidedAttrs.push_back(builtInName);
- }
+ printVariableDecorations(op, printer, elidedAttrs);
- printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
*printer << " : " << varOp.getType();
}
@@ -1006,8 +1176,11 @@ static LogicalResult verify(spirv::VariableOp varOp) {
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
- if (varOp.storage_class() == spirv::StorageClass::Generic)
- return varOp.emitOpError("storage class cannot be 'Generic'");
+ if (varOp.storage_class() != spirv::StorageClass::Function) {
+ return varOp.emitOpError(
+ "can only be used to model function-level variables. Use "
+ "spv.globalVariable for module-level variables.");
+ }
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
if (varOp.storage_class() != pointerType.getStorageClass())
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 1aad7173dc6..a3d71eda5d9 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -90,9 +90,20 @@ private:
/// them to their handler method accordingly.
LogicalResult processFunction(ArrayRef<uint32_t> operands);
+ /// Process the OpVariable instructions at current `offset` into `binary`. It
+ /// is expected that this method is used for variables that are to be defined
+ /// at module scope and will be deserialized into a spv.globalVariable
+ /// instruction.
+ LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
+
/// Get the FuncOp associated with a result <id> of OpFunction.
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
+ /// Get the global variable associated with a result <id> of OpVariable
+ spirv::GlobalVariableOp getVariable(uint32_t id) {
+ return globalVariableMap.lookup(id);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@@ -138,7 +149,15 @@ private:
//===--------------------------------------------------------------------===//
/// Get the Value associated with a result <id>.
- Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+ Value *getValue(uint32_t id) {
+ if (auto varOp = getVariable(id)) {
+ auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
+ unknownLoc, varOp.type(),
+ opBuilder.getSymbolRefAttr(varOp.getOperation()));
+ return addressOfOp.pointer();
+ }
+ return valueMap.lookup(id);
+ }
/// Slices the first instruction out of `binary` and returns its opcode and
/// operands via `opcode` and `operands` respectively. Returns failure if
@@ -198,6 +217,9 @@ private:
// Result <id> to function mapping.
DenseMap<uint32_t, FuncOp> funcMap;
+ // Result <id> to variable mapping;
+ DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+
// Result <id> to value mapping.
DenseMap<uint32_t, Value *> valueMap;
@@ -452,6 +474,76 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
+ unsigned wordIndex = 0;
+ if (operands.size() < 3) {
+ return emitError(
+ unknownLoc,
+ "OpVariable needs at least 3 operands, type, <id> and storage class");
+ }
+
+ // Result Type.
+ auto type = getType(operands[wordIndex]);
+ if (!type) {
+ return emitError(unknownLoc, "unknown result type <id> : ")
+ << operands[wordIndex];
+ }
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return emitError(unknownLoc,
+ "expected a result type <id> to be a spv.ptr, found : ")
+ << type;
+ }
+ wordIndex++;
+
+ // Result <id>.
+ auto variableID = operands[wordIndex];
+ auto variableName = nameMap.lookup(variableID).str();
+ if (variableName.empty()) {
+ variableName = "spirv_var_" + std::to_string(variableID);
+ }
+ wordIndex++;
+
+ // Storage class.
+ auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
+ if (ptrType.getStorageClass() != storageClass) {
+ return emitError(unknownLoc, "mismatch in storage class of pointer type ")
+ << type << " and that specified in OpVariable instruction : "
+ << stringifyStorageClass(storageClass);
+ }
+ wordIndex++;
+
+ // Initializer.
+ SymbolRefAttr initializer = nullptr;
+ if (wordIndex < operands.size()) {
+ auto initializerOp = getVariable(operands[wordIndex]);
+ if (!initializerOp) {
+ return emitError(unknownLoc, "unknown <id> ")
+ << operands[wordIndex] << "used as initializer";
+ }
+ wordIndex++;
+ initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
+ }
+ if (wordIndex != operands.size()) {
+ return emitError(unknownLoc,
+ "found more operands than expected when deserializing "
+ "OpVariable instruction, only ")
+ << wordIndex << " of " << operands.size() << " processed";
+ }
+ auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
+ unknownLoc, opBuilder.getTypeAttr(type),
+ opBuilder.getStringAttr(variableName), initializer);
+
+ // Decorations.
+ if (decorations.count(variableID)) {
+ for (auto attr : decorations[variableID].getAttrs()) {
+ varOp.setAttr(attr.first, attr.second);
+ }
+ }
+ globalVariableMap[variableID] = varOp;
+ return success();
+}
+
LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc, "OpName needs at least 2 operands");
@@ -887,6 +979,11 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return success();
}
break;
+ case spirv::Opcode::OpVariable:
+ if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
+ return processGlobalVariable(operands);
+ }
+ break;
case spirv::Opcode::OpName:
return processName(operands);
case spirv::Opcode::OpTypeVoid:
@@ -954,18 +1051,19 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
"and OpFunction with <id> ")
<< fnID << ": " << fnName << " vs. " << parsedFunc.getName();
}
- SmallVector<Value *, 4> interface;
+ SmallVector<Attribute, 4> interface;
while (wordIndex < words.size()) {
- auto arg = getValue(words[wordIndex]);
+ auto arg = getVariable(words[wordIndex]);
if (!arg) {
return emitError(unknownLoc, "undefined result <id> ")
<< words[wordIndex] << " while decoding OpEntryPoint";
}
- interface.push_back(arg);
+ interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
wordIndex++;
}
- opBuilder.create<spirv::EntryPointOp>(
- unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface);
+ opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+ opBuilder.getSymbolRefAttr(fnName),
+ opBuilder.getArrayAttr(interface));
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index d06363a1a8c..575d995bf45 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -125,9 +125,19 @@ private:
return funcIDMap.lookup(fnName);
}
+ uint32_t findVariableID(StringRef varName) const {
+ return globalVarIDMap.lookup(varName);
+ }
+
+ /// Emit OpName for the given `resultID`.
+ LogicalResult processName(uint32_t resultID, StringRef name);
+
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(FuncOp op);
+ /// Process a SPIR-V GlobalVariableOp
+ LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
+
/// Process attributes that translate to decorations on the result <id>
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
@@ -215,6 +225,9 @@ private:
uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
+ /// Process spv.addressOf operations.
+ LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
+
/// Main dispatch method for serializing an operation.
LogicalResult processOperation(Operation *op);
@@ -265,6 +278,9 @@ private:
/// Map from FuncOps name to <id>s.
llvm::StringMap<uint32_t> funcIDMap;
+ /// Map from GlobalVariableOps name to <id>s
+ llvm::StringMap<uint32_t> globalVarIDMap;
+
/// Map from results of normal operations to their <id>s
DenseMap<Value *, uint32_t> valueIDMap;
};
@@ -372,6 +388,15 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
+LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
+ SmallVector<uint32_t, 4> nameOperands;
+ nameOperands.push_back(resultID);
+ if (failed(encodeStringLiteralInto(nameOperands, name))) {
+ return failure();
+ }
+ return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+}
+
namespace {
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
@@ -416,10 +441,9 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
// Add function name.
- SmallVector<uint32_t, 4> nameOperands;
- nameOperands.push_back(funcID);
- encodeStringLiteralInto(nameOperands, op.getName());
- encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+ if (failed(processName(funcID, op.getName()))) {
+ return failure();
+ }
// Declare the parameters.
for (auto arg : op.getArguments()) {
@@ -450,6 +474,61 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
}
+LogicalResult
+Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
+ // Get TypeID.
+ uint32_t resultTypeID = 0;
+ SmallVector<StringRef, 4> elidedAttrs;
+ if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
+ return failure();
+ }
+ elidedAttrs.push_back("type");
+ SmallVector<uint32_t, 4> operands;
+ operands.push_back(resultTypeID);
+ auto resultID = getNextID();
+
+ // Encode the name.
+ auto varName = varOp.sym_name();
+ elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+ if (failed(processName(resultID, varName))) {
+ return failure();
+ }
+ globalVarIDMap[varName] = resultID;
+ operands.push_back(resultID);
+
+ // Encode StorageClass.
+ operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
+
+ // Encode initialization.
+ if (auto initializer = varOp.initializer()) {
+ auto initializerID = findVariableID(initializer.getValue());
+ if (!initializerID) {
+ return emitError(varOp.getLoc(),
+ "invalid usage of undefined variable as initializer");
+ }
+ operands.push_back(initializerID);
+ elidedAttrs.push_back("initializer");
+ }
+
+ if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable,
+ operands))) {
+ elidedAttrs.push_back("initializer");
+ return failure();
+ }
+
+ // Encode decorations.
+ for (auto attr : varOp.getAttrs()) {
+ if (llvm::any_of(elidedAttrs,
+ [&](StringRef elided) { return attr.first.is(elided); })) {
+ continue;
+ }
+ if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
@@ -912,6 +991,17 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
// Operation
//===----------------------------------------------------------------------===//
+LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
+ auto varName = addressOfOp.variable();
+ auto variableID = findVariableID(varName);
+ if (!variableID) {
+ return addressOfOp.emitError("unknown result <id> for variable ")
+ << varName;
+ }
+ valueIDMap[addressOfOp.pointer()] = variableID;
+ return success();
+}
+
LogicalResult Serializer::processOperation(Operation *op) {
// First dispatch the methods that do not directly mirror an operation from
// the SPIR-V spec
@@ -924,6 +1014,12 @@ LogicalResult Serializer::processOperation(Operation *op) {
if (isa<spirv::ModuleEndOp>(op)) {
return success();
}
+ if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+ return processGlobalVariableOp(varOp);
+ }
+ if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
+ return processAddressOfOp(addressOfOp);
+ }
return dispatchToAutogenSerialization(op);
}
@@ -947,14 +1043,16 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
encodeStringLiteralInto(operands, op.fn());
// Add the interface values.
- for (auto val : op.interface()) {
- auto id = findValueID(val);
- if (!id) {
- return op.emitError("referencing unintialized variable <id>. "
- "spv.EntryPoint is at the end of spv.module. All "
- "referenced variables should already be defined");
+ if (auto interface = op.interface()) {
+ for (auto var : interface.getValue()) {
+ auto id = findVariableID(var.cast<SymbolRefAttr>().getValue());
+ if (!id) {
+ return op.emitError("referencing undefined global variable."
+ "spv.EntryPoint is at the end of spv.module. All "
+ "referenced variables should already be defined");
+ }
+ operands.push_back(id);
}
- operands.push_back(id);
}
return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
operands);
OpenPOWER on IntegriCloud