diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | 23 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 347 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | 110 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 120 |
4 files changed, 485 insertions, 115 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 53a40dfa365..035de4f815d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -136,26 +136,26 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, signatureConverter, newFuncOp))) { return failure(); } - // Create spv.Variable ops for each of the arguments. These need to be bound - // by the runtime. For now use descriptor_set 0, and arg number as the binding - // number. + // Create spv.globalVariable ops for each of the arguments. These need to be + // bound by the runtime. For now use descriptor_set 0, and arg number as the + // binding number. auto module = funcOp.getParentOfType<spirv::ModuleOp>(); if (!module) { return funcOp.emitError("expected op to be within a spv.module"); } OpBuilder builder(module.getOperation()->getRegion(0)); - SmallVector<Value *, 4> interface; + SmallVector<Attribute, 4> interface; for (auto &convertedArgType : llvm::enumerate(signatureConverter.getConvertedTypes())) { - auto variableOp = builder.create<spirv::VariableOp>( - funcOp.getLoc(), convertedArgType.value(), - builder.getI32IntegerAttr( - static_cast<int32_t>(spirv::StorageClass::StorageBuffer)), - llvm::None); + std::string varName = funcOp.getName().str() + "_arg_" + + std::to_string(convertedArgType.index()); + auto variableOp = builder.create<spirv::GlobalVariableOp>( + funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()), + builder.getStringAttr(varName), nullptr); variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0)); variableOp.setAttr("binding", builder.getI32IntegerAttr(convertedArgType.index())); - interface.push_back(variableOp.getResult()); + interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name())); } // Create an entry point instruction for this function. // TODO(ravishankarm) : Add execution mode for the entry function @@ -164,7 +164,8 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, funcOp.getLoc(), builder.getI32IntegerAttr( static_cast<int32_t>(spirv::ExecutionModel::GLCompute)), - builder.getSymbolRefAttr(newFuncOp.getName()), interface); + builder.getSymbolRefAttr(newFuncOp.getName()), + builder.getArrayAttr(interface)); return success(); } } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 4bea441c366..9947c0254a9 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -32,11 +32,15 @@ using namespace mlir; // TODO(antiagainst): generate these strings using ODS. static constexpr const char kAlignmentAttrName[] = "alignment"; +static constexpr const char kFnNameAttrName[] = "fn"; static constexpr const char kIndicesAttrName[] = "indices"; +static constexpr const char kInitializerAttrName[] = "initializer"; +static constexpr const char kInterfaceAttrName[] = "interface"; static constexpr const char kIsSpecConstName[] = "is_spec_const"; +static constexpr const char kTypeAttrName[] = "type"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; -static constexpr const char kFnNameAttrName[] = "fn"; +static constexpr const char kVariableAttrName[] = "variable"; //===----------------------------------------------------------------------===// // Common utility functions @@ -239,6 +243,71 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) { printer->printOptionalAttrDict(op->getAttrs()); } +static ParseResult parseVariableDecorations(OpAsmParser *parser, + OperationState *state) { + auto builtInName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); + if (succeeded(parser->parseOptionalKeyword("bind"))) { + Attribute set, binding; + // Parse optional descriptor binding + auto descriptorSetName = convertToSnakeCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + Type i32Type = parser->getBuilder().getIntegerType(32); + if (parser->parseLParen() || + parser->parseAttribute(set, i32Type, descriptorSetName, + state->attributes) || + parser->parseComma() || + parser->parseAttribute(binding, i32Type, bindingName, + state->attributes) || + parser->parseRParen()) { + return failure(); + } + } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) { + StringAttr builtIn; + if (parser->parseLParen() || + parser->parseAttribute(builtIn, Type(), builtInName, + state->attributes) || + parser->parseRParen()) { + return failure(); + } + } + + // Parse other attributes + if (parser->parseOptionalAttributeDict(state->attributes)) + return failure(); + + return success(); +} + +static void printVariableDecorations(Operation *op, OpAsmPrinter *printer, + SmallVectorImpl<StringRef> &elidedAttrs) { + // Print optional descriptor binding + auto descriptorSetName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName); + auto binding = op->getAttrOfType<IntegerAttr>(bindingName); + if (descriptorSet && binding) { + elidedAttrs.push_back(descriptorSetName); + elidedAttrs.push_back(bindingName); + *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() + << ")"; + } + + // Print BuiltIn attribute if present + auto builtInName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); + if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) { + *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; + elidedAttrs.push_back(builtInName); + } + + printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); +} + //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// @@ -363,6 +432,53 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) { } //===----------------------------------------------------------------------===// +// spv._address_of +//===----------------------------------------------------------------------===// + +static ParseResult parseAddressOfOp(OpAsmParser *parser, + OperationState *state) { + SymbolRefAttr varRefAttr; + Type type; + if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName, + state->attributes) || + parser->parseColonType(type)) { + return failure(); + } + auto ptrType = type.dyn_cast<spirv::PointerType>(); + if (!ptrType) { + return parser->emitError(parser->getCurrentLocation(), + "expected spv.ptr type"); + } + state->addTypes(ptrType); + return success(); +} + +static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) { + SmallVector<StringRef, 4> elidedAttrs; + *printer << spirv::AddressOfOp::getOperationName(); + + // Print symbol name. + *printer << " @" << addressOfOp.variable(); + + // Print the type. + *printer << " : " << addressOfOp.pointer(); +} + +static LogicalResult verify(spirv::AddressOfOp addressOfOp) { + auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>(); + auto varOp = + moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable()); + if (!varOp) { + return addressOfOp.emitError("expected spv.globalVariable symbol"); + } + if (addressOfOp.pointer()->getType() != varOp.type()) { + return addressOfOp.emitError( + "mismatch in result type and type of global variable referenced"); + } + return success(); +} + +//===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// @@ -541,18 +657,28 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser, SmallVector<OpAsmParser::OperandType, 0> identifiers; SmallVector<Type, 0> idTypes; - Attribute fn; - auto loc = parser->getCurrentLocation(); - + SymbolRefAttr fn; if (parseEnumAttribute(execModel, parser, state) || - parser->parseAttribute(fn, kFnNameAttrName, state->attributes) || - parser->parseTrailingOperandList(identifiers) || - parser->parseOptionalColonTypeList(idTypes) || - parser->resolveOperands(identifiers, idTypes, loc, state->operands)) { + parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) { return failure(); } - if (!fn.isa<SymbolRefAttr>()) { - return parser->emitError(loc, "expected symbol reference attribute"); + + if (!parser->parseOptionalComma()) { + // Parse the interface variables + SmallVector<Attribute, 4> interfaceVars; + do { + // The name of the interface variable attribute isnt important + auto attrName = "var_symbol"; + SymbolRefAttr var; + SmallVector<NamedAttribute, 1> attrs; + if (parser->parseAttribute(var, Type(), attrName, attrs)) { + return failure(); + } + interfaceVars.push_back(var); + } while (!parser->parseOptionalComma()); + state->attributes.push_back( + {parser->getBuilder().getIdentifier(kInterfaceAttrName), + parser->getBuilder().getArrayAttr(interfaceVars)}); } return success(); } @@ -561,27 +687,16 @@ static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) { *printer << spirv::EntryPointOp::getOperationName() << " \"" << stringifyExecutionModel(entryPointOp.execution_model()) << "\" @" << entryPointOp.fn(); - if (!entryPointOp.getNumOperands()) { - return; + if (auto interface = entryPointOp.interface()) { + *printer << ", "; + mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(), + [&](Attribute a) { printer->printAttribute(a); }); } - *printer << ", "; - mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(), - [&](Value *a) { printer->printOperand(a); }); - *printer << " : "; - mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(), - [&](const Value *a) { *printer << a->getType(); }); } static LogicalResult verify(spirv::EntryPointOp entryPointOp) { - // Verify that all the interface ops are created from VariableOp - for (auto interface : entryPointOp.interface()) { - if (!llvm::isa_and_nonnull<spirv::VariableOp>(interface->getDefiningOp())) { - return entryPointOp.emitOpError("interface operands to entry point must " - "be generated from a variable op"); - } - // TODO: Before version 1.4 the variables can only have storage_class of - // Input or Output. That needs to be verified. - } + // Checks for fn and interface symbol reference are done in spirv::ModuleOp + // verification. return success(); } @@ -628,6 +743,95 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) { } //===----------------------------------------------------------------------===// +// spv.globalVariable +//===----------------------------------------------------------------------===// + +static ParseResult parseGlobalVariableOp(OpAsmParser *parser, + OperationState *state) { + // Parse variable type. + TypeAttr typeAttr; + auto loc = parser->getCurrentLocation(); + if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName, + state->attributes)) { + return failure(); + } + auto ptrType = typeAttr.getValue().dyn_cast<spirv::PointerType>(); + if (!ptrType) { + return parser->emitError(loc, "expected spv.ptr type"); + } + + // Parse variable name. + StringAttr nameAttr; + if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + state->attributes)) { + return failure(); + } + + // Parse optional initializer + if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) { + SymbolRefAttr initSymbol; + if (parser->parseLParen() || + parser->parseAttribute(initSymbol, Type(), kInitializerAttrName, + state->attributes) || + parser->parseRParen()) + return failure(); + } + + if (parseVariableDecorations(parser, state)) { + return failure(); + } + + return success(); +} + +static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) { + auto *op = varOp.getOperation(); + SmallVector<StringRef, 4> elidedAttrs{ + spirv::attributeName<spirv::StorageClass>()}; + *printer << spirv::GlobalVariableOp::getOperationName(); + + // Print variable type. + *printer << " " << varOp.type(); + elidedAttrs.push_back(kTypeAttrName); + + // Print variable name. + *printer << " @" << varOp.sym_name(); + elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); + + // Print optional initializer + if (auto initializer = varOp.initializer()) { + *printer << " " << kInitializerAttrName << "(@" << initializer.getValue() + << ")"; + elidedAttrs.push_back(kInitializerAttrName); + } + printVariableDecorations(op, printer, elidedAttrs); +} + +static LogicalResult verify(spirv::GlobalVariableOp varOp) { + // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the + // object. It cannot be Generic. It must be the same as the Storage Class + // operand of the Result Type." + if (varOp.storageClass() == spirv::StorageClass::Generic) + return varOp.emitOpError("storage class cannot be 'Generic'"); + + if (auto initializer = + varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) { + // Get the module + auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>(); + // TODO: Currently only variable initialization with other variables is + // supported. They could be constants as well, but this needs module-level + // constants to have symbol name as well. + if (!moduleOp.lookupSymbol<spirv::GlobalVariableOp>( + initializer.getValue())) { + return varOp.emitOpError( + "initializer must be result of a spv.globalVariable op"); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// @@ -773,13 +977,33 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { for (auto &op : body) { if (op.getDialect() == dialect) { // For EntryPoint op, check that the function and execution model is not - // duplicated in EntryPointOps + // duplicated in EntryPointOps. Also verify that the interface specified + // comes from globalVariables here to make this check cheaper. if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) { auto funcOp = table.lookup<FuncOp>(entryPointOp.fn()); if (!funcOp) { return entryPointOp.emitError("function '") << entryPointOp.fn() << "' not found in 'spv.module'"; } + if (auto interface = entryPointOp.interface()) { + for (auto varRef : interface.getValue().getValue()) { + auto varSymRef = varRef.dyn_cast<SymbolRefAttr>(); + if (!varSymRef) { + return entryPointOp.emitError( + "expected symbol reference for interface " + "specification instead of '") + << varRef; + } + auto variableOp = + table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue()); + if (!variableOp) { + return entryPointOp.emitError("expected spv.globalVariable " + "symbol reference instead of'") + << varSymRef << "'"; + } + } + } + auto key = std::pair<FuncOp, spirv::ExecutionModel>( funcOp, entryPointOp.execution_model()); auto entryPtIt = entryPoints.find(key); @@ -898,42 +1122,9 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { return failure(); } - auto builtInName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); - if (succeeded(parser->parseOptionalKeyword("bind"))) { - Attribute set, binding; - // Parse optional descriptor binding - auto descriptorSetName = convertToSnakeCase( - stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); - Type i32Type = parser->getBuilder().getIntegerType(32); - if (parser->parseLParen() || - parser->parseAttribute(set, i32Type, descriptorSetName, - state->attributes) || - parser->parseComma() || - parser->parseAttribute(binding, i32Type, bindingName, - state->attributes) || - parser->parseRParen()) { - return failure(); - } - } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) { - Attribute builtIn; - if (parser->parseLParen() || - parser->parseAttribute(builtIn, Type(), builtInName, - state->attributes) || - parser->parseRParen()) { - return failure(); - } - if (!builtIn.isa<StringAttr>()) { - return parser->emitError(parser->getCurrentLocation(), - "expected string value for built_in attribute"); - } - } - - // Parse other attributes - if (parser->parseOptionalAttributeDict(state->attributes)) + if (parseVariableDecorations(parser, state)) { return failure(); + } // Parse result pointer type Type type; @@ -976,29 +1167,8 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { *printer << ")"; } - // Print optional descriptor binding - auto descriptorSetName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); - auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName); - auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName); - if (descriptorSet && binding) { - elidedAttrs.push_back(descriptorSetName); - elidedAttrs.push_back(bindingName); - *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() - << ")"; - } - - // Print BuiltIn attribute if present - auto builtInName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); - if (auto builtin = varOp.getAttrOfType<StringAttr>(builtInName)) { - *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; - elidedAttrs.push_back(builtInName); - } + printVariableDecorations(op, printer, elidedAttrs); - printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); *printer << " : " << varOp.getType(); } @@ -1006,8 +1176,11 @@ static LogicalResult verify(spirv::VariableOp varOp) { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." - if (varOp.storage_class() == spirv::StorageClass::Generic) - return varOp.emitOpError("storage class cannot be 'Generic'"); + if (varOp.storage_class() != spirv::StorageClass::Function) { + return varOp.emitOpError( + "can only be used to model function-level variables. Use " + "spv.globalVariable for module-level variables."); + } auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>(); if (varOp.storage_class() != pointerType.getStorageClass()) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 1aad7173dc6..a3d71eda5d9 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -90,9 +90,20 @@ private: /// them to their handler method accordingly. LogicalResult processFunction(ArrayRef<uint32_t> operands); + /// Process the OpVariable instructions at current `offset` into `binary`. It + /// is expected that this method is used for variables that are to be defined + /// at module scope and will be deserialized into a spv.globalVariable + /// instruction. + LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands); + /// Get the FuncOp associated with a result <id> of OpFunction. FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } + /// Get the global variable associated with a result <id> of OpVariable + spirv::GlobalVariableOp getVariable(uint32_t id) { + return globalVariableMap.lookup(id); + } + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// @@ -138,7 +149,15 @@ private: //===--------------------------------------------------------------------===// /// Get the Value associated with a result <id>. - Value *getValue(uint32_t id) { return valueMap.lookup(id); } + Value *getValue(uint32_t id) { + if (auto varOp = getVariable(id)) { + auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( + unknownLoc, varOp.type(), + opBuilder.getSymbolRefAttr(varOp.getOperation())); + return addressOfOp.pointer(); + } + return valueMap.lookup(id); + } /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if @@ -198,6 +217,9 @@ private: // Result <id> to function mapping. DenseMap<uint32_t, FuncOp> funcMap; + // Result <id> to variable mapping; + DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; + // Result <id> to value mapping. DenseMap<uint32_t, Value *> valueMap; @@ -452,6 +474,76 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { return success(); } +LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { + unsigned wordIndex = 0; + if (operands.size() < 3) { + return emitError( + unknownLoc, + "OpVariable needs at least 3 operands, type, <id> and storage class"); + } + + // Result Type. + auto type = getType(operands[wordIndex]); + if (!type) { + return emitError(unknownLoc, "unknown result type <id> : ") + << operands[wordIndex]; + } + auto ptrType = type.dyn_cast<spirv::PointerType>(); + if (!ptrType) { + return emitError(unknownLoc, + "expected a result type <id> to be a spv.ptr, found : ") + << type; + } + wordIndex++; + + // Result <id>. + auto variableID = operands[wordIndex]; + auto variableName = nameMap.lookup(variableID).str(); + if (variableName.empty()) { + variableName = "spirv_var_" + std::to_string(variableID); + } + wordIndex++; + + // Storage class. + auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); + if (ptrType.getStorageClass() != storageClass) { + return emitError(unknownLoc, "mismatch in storage class of pointer type ") + << type << " and that specified in OpVariable instruction : " + << stringifyStorageClass(storageClass); + } + wordIndex++; + + // Initializer. + SymbolRefAttr initializer = nullptr; + if (wordIndex < operands.size()) { + auto initializerOp = getVariable(operands[wordIndex]); + if (!initializerOp) { + return emitError(unknownLoc, "unknown <id> ") + << operands[wordIndex] << "used as initializer"; + } + wordIndex++; + initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); + } + if (wordIndex != operands.size()) { + return emitError(unknownLoc, + "found more operands than expected when deserializing " + "OpVariable instruction, only ") + << wordIndex << " of " << operands.size() << " processed"; + } + auto varOp = opBuilder.create<spirv::GlobalVariableOp>( + unknownLoc, opBuilder.getTypeAttr(type), + opBuilder.getStringAttr(variableName), initializer); + + // Decorations. + if (decorations.count(variableID)) { + for (auto attr : decorations[variableID].getAttrs()) { + varOp.setAttr(attr.first, attr.second); + } + } + globalVariableMap[variableID] = varOp; + return success(); +} + LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpName needs at least 2 operands"); @@ -887,6 +979,11 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return success(); } break; + case spirv::Opcode::OpVariable: + if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { + return processGlobalVariable(operands); + } + break; case spirv::Opcode::OpName: return processName(operands); case spirv::Opcode::OpTypeVoid: @@ -954,18 +1051,19 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { "and OpFunction with <id> ") << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); } - SmallVector<Value *, 4> interface; + SmallVector<Attribute, 4> interface; while (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); + auto arg = getVariable(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "undefined result <id> ") << words[wordIndex] << " while decoding OpEntryPoint"; } - interface.push_back(arg); + interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } - opBuilder.create<spirv::EntryPointOp>( - unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface); + opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model, + opBuilder.getSymbolRefAttr(fnName), + opBuilder.getArrayAttr(interface)); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index d06363a1a8c..575d995bf45 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -125,9 +125,19 @@ private: return funcIDMap.lookup(fnName); } + uint32_t findVariableID(StringRef varName) const { + return globalVarIDMap.lookup(varName); + } + + /// Emit OpName for the given `resultID`. + LogicalResult processName(uint32_t resultID, StringRef name); + /// Processes a SPIR-V function op. LogicalResult processFuncOp(FuncOp op); + /// Process a SPIR-V GlobalVariableOp + LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); + /// Process attributes that translate to decorations on the result <id> LogicalResult processDecoration(Location loc, uint32_t resultID, NamedAttribute attr); @@ -215,6 +225,9 @@ private: uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); } + /// Process spv.addressOf operations. + LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); + /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); @@ -265,6 +278,9 @@ private: /// Map from FuncOps name to <id>s. llvm::StringMap<uint32_t> funcIDMap; + /// Map from GlobalVariableOps name to <id>s + llvm::StringMap<uint32_t> globalVarIDMap; + /// Map from results of normal operations to their <id>s DenseMap<Value *, uint32_t> valueIDMap; }; @@ -372,6 +388,15 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); } +LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { + SmallVector<uint32_t, 4> nameOperands; + nameOperands.push_back(resultID); + if (failed(encodeStringLiteralInto(nameOperands, name))) { + return failure(); + } + return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); +} + namespace { template <> LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( @@ -416,10 +441,9 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands); // Add function name. - SmallVector<uint32_t, 4> nameOperands; - nameOperands.push_back(funcID); - encodeStringLiteralInto(nameOperands, op.getName()); - encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); + if (failed(processName(funcID, op.getName()))) { + return failure(); + } // Declare the parameters. for (auto arg : op.getArguments()) { @@ -450,6 +474,61 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); } +LogicalResult +Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { + // Get TypeID. + uint32_t resultTypeID = 0; + SmallVector<StringRef, 4> elidedAttrs; + if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { + return failure(); + } + elidedAttrs.push_back("type"); + SmallVector<uint32_t, 4> operands; + operands.push_back(resultTypeID); + auto resultID = getNextID(); + + // Encode the name. + auto varName = varOp.sym_name(); + elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); + if (failed(processName(resultID, varName))) { + return failure(); + } + globalVarIDMap[varName] = resultID; + operands.push_back(resultID); + + // Encode StorageClass. + operands.push_back(static_cast<uint32_t>(varOp.storageClass())); + + // Encode initialization. + if (auto initializer = varOp.initializer()) { + auto initializerID = findVariableID(initializer.getValue()); + if (!initializerID) { + return emitError(varOp.getLoc(), + "invalid usage of undefined variable as initializer"); + } + operands.push_back(initializerID); + elidedAttrs.push_back("initializer"); + } + + if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable, + operands))) { + elidedAttrs.push_back("initializer"); + return failure(); + } + + // Encode decorations. + for (auto attr : varOp.getAttrs()) { + if (llvm::any_of(elidedAttrs, + [&](StringRef elided) { return attr.first.is(elided); })) { + continue; + } + if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { + return failure(); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// @@ -912,6 +991,17 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, // Operation //===----------------------------------------------------------------------===// +LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { + auto varName = addressOfOp.variable(); + auto variableID = findVariableID(varName); + if (!variableID) { + return addressOfOp.emitError("unknown result <id> for variable ") + << varName; + } + valueIDMap[addressOfOp.pointer()] = variableID; + return success(); +} + LogicalResult Serializer::processOperation(Operation *op) { // First dispatch the methods that do not directly mirror an operation from // the SPIR-V spec @@ -924,6 +1014,12 @@ LogicalResult Serializer::processOperation(Operation *op) { if (isa<spirv::ModuleEndOp>(op)) { return success(); } + if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { + return processGlobalVariableOp(varOp); + } + if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) { + return processAddressOfOp(addressOfOp); + } return dispatchToAutogenSerialization(op); } @@ -947,14 +1043,16 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { encodeStringLiteralInto(operands, op.fn()); // Add the interface values. - for (auto val : op.interface()) { - auto id = findValueID(val); - if (!id) { - return op.emitError("referencing unintialized variable <id>. " - "spv.EntryPoint is at the end of spv.module. All " - "referenced variables should already be defined"); + if (auto interface = op.interface()) { + for (auto var : interface.getValue()) { + auto id = findVariableID(var.cast<SymbolRefAttr>().getValue()); + if (!id) { + return op.emitError("referencing undefined global variable." + "spv.EntryPoint is at the end of spv.module. All " + "referenced variables should already be defined"); + } + operands.push_back(id); } - operands.push_back(id); } return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); |