diff options
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | 354 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 209 |
3 files changed, 312 insertions, 255 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 83f474cd780..66544f0730a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -539,8 +539,8 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> : // Controls whether the (de)serialization method is generated automatically or // not. This results in generation of the following methods: // - // template<typename OpTy> Serialization::processOp(OpTy op) - // template<typename OpTy> Deserialization::processOp(ArrayRef<uint32_t>) + // template<typename OpTy> Serializer::processOp(OpTy op) + // template<typename OpTy> Deserializer::processOp(ArrayRef<uint32_t>) // // If the auto generation is disabled (set to 0), then manual implementation // of a specialization of these methods is required. diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 676ef8ec871..1d3a934f6f8 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -63,35 +63,63 @@ public: Optional<spirv::ModuleOp> collect(); private: - /// Get type for a given result <id> - Type getType(uint32_t id) { return typeMap.lookup(id); } + //===--------------------------------------------------------------------===// + // Module structure + //===--------------------------------------------------------------------===// - /// Get Value associated with a result <id> - Value *getValue(uint32_t id) { return valueMap.lookup(id); } + /// Initializes the `module` ModuleOp in this deserializer instance. + spirv::ModuleOp createModuleOp(); + + /// Processes SPIR-V module header in `binary`. + LogicalResult processHeader(); + + /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`. + LogicalResult processMemoryModel(ArrayRef<uint32_t> operands); + + /// Processes the SPIR-V function at the current `offset` into `binary`. + /// The operands to the OpFunction instruction is passed in as ``operands`. + /// This method processes each instruction inside the function and dispatches + /// them to their handler method accordingly. + LogicalResult processFunction(ArrayRef<uint32_t> operands); - // Check if a type is void + //===--------------------------------------------------------------------===// + // Type + //===--------------------------------------------------------------------===// + + /// Gets type for a given result <id>. + Type getType(uint32_t id) { return typeMap.lookup(id); } + + /// Returns true if the given `type` is for SPIR-V void type. bool isVoidType(Type type) const { return type.isa<NoneType>(); } - /// Processes SPIR-V module header. - LogicalResult processHeader(); + /// Processes a SPIR-V type instruction with given `opcode` and `operands` and + /// registers the type into `module`. + LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands); - /// Deserialize a single instruction. The |opcode| and |operands| are returned - /// after deserialization to the caller. - LogicalResult deserializeInstruction(spirv::Opcode &opcode, - ArrayRef<uint32_t> &operands); + LogicalResult processFunctionType(ArrayRef<uint32_t> operands); + + //===--------------------------------------------------------------------===// + // Instruction + //===--------------------------------------------------------------------===// + + /// Get the Value associated with a result <id>. + Value *getValue(uint32_t id) { return valueMap.lookup(id); } + + /// Slices the first instruction out of `binary` and returns its opcode and + /// operands via `opcode` and `operands` respectively. + LogicalResult sliceInstruction(spirv::Opcode &opcode, + ArrayRef<uint32_t> &operands); /// Processes a SPIR-V instruction with the given `opcode` and `operands`. + /// This method is the main entrance for handling SPIR-V instruction; it + /// checks the instruction opcode and dispatches to the corresponding handler. LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef<uint32_t> operands); - /// Processes a SPIR-V type instruction with given 'opcode' and 'operands' - LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands); - LogicalResult processFunctionType(ArrayRef<uint32_t> operands); - /// Method to dispatch to the specialized deserialization function for an - /// operation in SPIR-V dialect that is a mirror of an operation in the SPIR-V - /// spec. This is auto-generated from ODS. Dispatch is handled for all - /// operations in SPIR-V dialect that have hasOpcode == 1 + /// operation in SPIR-V dialect that is a mirror of an instruction in the + /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for + /// all operations in SPIR-V dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, ArrayRef<uint32_t> words); @@ -101,20 +129,13 @@ private: template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) { return processOpImpl<OpTy>(words); } + template <typename OpTy> LogicalResult processOpImpl(ArrayRef<uint32_t> words) { - return emitError(unknownLoc, "unsupported deserialization for op '") - << OpTy::getOperationName() << "')"; + return emitError(unknownLoc, "unsupported deserialization for ") + << OpTy::getOperationName() << " op"; } - /// Process function objects in binary - LogicalResult processFunction(ArrayRef<uint32_t> operands); - - LogicalResult processMemoryModel(ArrayRef<uint32_t> operands); - - /// Initializes the `module` ModuleOp in this deserializer instance. - spirv::ModuleOp createModuleOp(); - private: /// The SPIR-V binary module. ArrayRef<uint32_t> binary; @@ -155,7 +176,7 @@ LogicalResult Deserializer::deserialize() { spirv::Opcode opcode; ArrayRef<uint32_t> operands; - while (succeeded(deserializeInstruction(opcode, operands))) { + while (succeeded(sliceInstruction(opcode, operands))) { if (failed(processInstruction(opcode, operands))) return failure(); } @@ -165,6 +186,20 @@ LogicalResult Deserializer::deserialize() { Optional<spirv::ModuleOp> Deserializer::collect() { return module; } +//===----------------------------------------------------------------------===// +// Module structure +//===----------------------------------------------------------------------===// + +spirv::ModuleOp Deserializer::createModuleOp() { + Builder builder(context); + OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); + // TODO(antiagainst): use target environment to select the version + state.addAttribute("major_version", builder.getI32IntegerAttr(1)); + state.addAttribute("minor_version", builder.getI32IntegerAttr(0)); + spirv::ModuleOp::build(&builder, &state); + return llvm::cast<spirv::ModuleOp>(Operation::create(state)); +} + LogicalResult Deserializer::processHeader() { if (binary.size() < spirv::kHeaderWordCount) return emitError(unknownLoc, @@ -178,113 +213,16 @@ LogicalResult Deserializer::processHeader() { return success(); } -LogicalResult -Deserializer::deserializeInstruction(spirv::Opcode &opcode, - ArrayRef<uint32_t> &operands) { - auto binarySize = binary.size(); - if (curOffset >= binarySize) { - return failure(); - } - // For each instruction, get its word count from the first word to slice it - // from the stream properly, and then dispatch to the instruction handler. - - uint32_t wordCount = binary[curOffset] >> 16; - opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff); - - if (wordCount == 0) - return emitError(unknownLoc, "word count cannot be zero"); - - uint32_t nextOffset = curOffset + wordCount; - if (nextOffset > binarySize) - return emitError(unknownLoc, "insufficient words for the last instruction"); - - operands = binary.slice(curOffset + 1, wordCount - 1); - curOffset = nextOffset; - return success(); -} +LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { + if (operands.size() != 2) + return emitError(unknownLoc, "OpMemoryModel must have two operands"); -LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { - assert(!operands.empty() && "No operands for processing function type"); - if (operands.size() == 1) { - return emitError(unknownLoc, "missing return type for OpTypeFunction"); - } - auto returnType = getType(operands[1]); - if (!returnType) { - return emitError(unknownLoc, "unknown return type in OpTypeFunction"); - } - SmallVector<Type, 1> argTypes; - for (size_t i = 2, e = operands.size(); i < e; ++i) { - auto ty = getType(operands[i]); - if (!ty) { - return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); - } - argTypes.push_back(ty); - } - ArrayRef<Type> returnTypes; - if (!isVoidType(returnType)) { - returnTypes = llvm::makeArrayRef(returnType); - } - typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); - return success(); -} + module->setAttr( + "addressing_model", + opBuilder.getI32IntegerAttr(bitwiseCast<int32_t>(operands.front()))); + module->setAttr("memory_model", opBuilder.getI32IntegerAttr( + bitwiseCast<int32_t>(operands.back()))); -LogicalResult Deserializer::processType(spirv::Opcode opcode, - ArrayRef<uint32_t> operands) { - if (operands.empty()) { - return emitError(unknownLoc, "type instruction with opcode ") - << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; - } - /// TODO: Types might be forward declared in some instructions and need to be - /// handled appropriately. - if (typeMap.count(operands[0])) { - return emitError(unknownLoc, "duplicate definition for result <id> ") - << operands[0]; - } - switch (opcode) { - case spirv::Opcode::OpTypeVoid: - if (operands.size() != 1) { - return emitError(unknownLoc, "OpTypeVoid must have no parameters"); - } - typeMap[operands[0]] = NoneType::get(context); - break; - case spirv::Opcode::OpTypeFloat: { - if (operands.size() != 2) { - return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); - } - Type floatTy; - switch (operands[1]) { - case 16: - floatTy = opBuilder.getF16Type(); - break; - case 32: - floatTy = opBuilder.getF32Type(); - break; - case 64: - floatTy = opBuilder.getF64Type(); - break; - default: - return emitError(unknownLoc, "unsupported bitwdith ") - << operands[1] << " with OpTypeFloat"; - } - typeMap[operands[0]] = floatTy; - } break; - case spirv::Opcode::OpTypePointer: { - if (operands.size() != 3) { - return emitError(unknownLoc, "OpTypePointer must have two parameters"); - } - auto pointeeType = getType(operands[2]); - if (!pointeeType) { - return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> : ") - << operands[2]; - } - auto storageClass = static_cast<spirv::StorageClass>(operands[1]); - typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); - } break; - case spirv::Opcode::OpTypeFunction: - return processFunctionType(operands); - default: - return emitError(unknownLoc, "unhandled type instruction"); - } return success(); } @@ -335,7 +273,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { auto argType = functionType.getInput(i); spirv::Opcode opcode; ArrayRef<uint32_t> operands; - if (failed(deserializeInstruction(opcode, operands))) { + if (failed(sliceInstruction(opcode, operands))) { return failure(); } if (opcode != spirv::Opcode::OpFunctionParameter) { @@ -372,7 +310,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { spirv::Opcode opcode; ArrayRef<uint32_t> instOperands; - while (succeeded(deserializeInstruction(opcode, instOperands)) && + while (succeeded(sliceInstruction(opcode, instOperands)) && opcode != spirv::Opcode::OpFunctionEnd) { if (failed(processInstruction(opcode, instOperands))) { return failure(); @@ -388,8 +326,122 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { return success(); } -#define GET_DESERIALIZATION_FNS -#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" +//===----------------------------------------------------------------------===// +// Type +//===----------------------------------------------------------------------===// + +LogicalResult Deserializer::processType(spirv::Opcode opcode, + ArrayRef<uint32_t> operands) { + if (operands.empty()) { + return emitError(unknownLoc, "type instruction with opcode ") + << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; + } + /// TODO: Types might be forward declared in some instructions and need to be + /// handled appropriately. + if (typeMap.count(operands[0])) { + return emitError(unknownLoc, "duplicate definition for result <id> ") + << operands[0]; + } + switch (opcode) { + case spirv::Opcode::OpTypeVoid: + if (operands.size() != 1) { + return emitError(unknownLoc, "OpTypeVoid must have no parameters"); + } + typeMap[operands[0]] = NoneType::get(context); + break; + case spirv::Opcode::OpTypeFloat: { + if (operands.size() != 2) { + return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); + } + Type floatTy; + switch (operands[1]) { + case 16: + floatTy = opBuilder.getF16Type(); + break; + case 32: + floatTy = opBuilder.getF32Type(); + break; + case 64: + floatTy = opBuilder.getF64Type(); + break; + default: + return emitError(unknownLoc, "unsupported bitwdith ") + << operands[1] << " with OpTypeFloat"; + } + typeMap[operands[0]] = floatTy; + } break; + case spirv::Opcode::OpTypePointer: { + if (operands.size() != 3) { + return emitError(unknownLoc, "OpTypePointer must have two parameters"); + } + auto pointeeType = getType(operands[2]); + if (!pointeeType) { + return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> : ") + << operands[2]; + } + auto storageClass = static_cast<spirv::StorageClass>(operands[1]); + typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); + } break; + case spirv::Opcode::OpTypeFunction: + return processFunctionType(operands); + default: + return emitError(unknownLoc, "unhandled type instruction"); + } + return success(); +} + +LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { + assert(!operands.empty() && "No operands for processing function type"); + if (operands.size() == 1) { + return emitError(unknownLoc, "missing return type for OpTypeFunction"); + } + auto returnType = getType(operands[1]); + if (!returnType) { + return emitError(unknownLoc, "unknown return type in OpTypeFunction"); + } + SmallVector<Type, 1> argTypes; + for (size_t i = 2, e = operands.size(); i < e; ++i) { + auto ty = getType(operands[i]); + if (!ty) { + return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); + } + argTypes.push_back(ty); + } + ArrayRef<Type> returnTypes; + if (!isVoidType(returnType)) { + returnTypes = llvm::makeArrayRef(returnType); + } + typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); + return success(); +} + +//===----------------------------------------------------------------------===// +// Instruction +//===----------------------------------------------------------------------===// + +LogicalResult Deserializer::sliceInstruction(spirv::Opcode &opcode, + ArrayRef<uint32_t> &operands) { + auto binarySize = binary.size(); + if (curOffset >= binarySize) { + return failure(); + } + // For each instruction, get its word count from the first word to slice it + // from the stream properly, and then dispatch to the instruction handler. + + uint32_t wordCount = binary[curOffset] >> 16; + opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff); + + if (wordCount == 0) + return emitError(unknownLoc, "word count cannot be zero"); + + uint32_t nextOffset = curOffset + wordCount; + if (nextOffset > binarySize) + return emitError(unknownLoc, "insufficient words for the last instruction"); + + operands = binary.slice(curOffset + 1, wordCount - 1); + curOffset = nextOffset; + return success(); +} LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, ArrayRef<uint32_t> operands) { @@ -410,28 +462,10 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return dispatchToAutogenDeserialization(opcode, operands); } -LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { - if (operands.size() != 2) - return emitError(unknownLoc, "OpMemoryModel must have two operands"); - - module->setAttr( - "addressing_model", - opBuilder.getI32IntegerAttr(bitwiseCast<int32_t>(operands.front()))); - module->setAttr("memory_model", opBuilder.getI32IntegerAttr( - bitwiseCast<int32_t>(operands.back()))); - - return success(); -} - -spirv::ModuleOp Deserializer::createModuleOp() { - Builder builder(context); - OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); - // TODO(antiagainst): use target environment to select the version - state.addAttribute("major_version", builder.getI32IntegerAttr(1)); - state.addAttribute("minor_version", builder.getI32IntegerAttr(0)); - spirv::ModuleOp::build(&builder, &state); - return llvm::cast<spirv::ModuleOp>(Operation::create(state)); -} +// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and +// various processOpImpl specializations. +#define GET_DESERIALIZATION_FNS +#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary, MLIRContext *context) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index b1c5936ae1a..58c89aa219a 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -73,64 +73,81 @@ public: void collect(SmallVectorImpl<uint32_t> &binary); private: + uint32_t getNextID() { return nextID++; } + + //===--------------------------------------------------------------------===// + // Module structure + //===--------------------------------------------------------------------===// + /// Creates SPIR-V module header in the given `header`. LogicalResult processHeader(); LogicalResult processMemoryModel(); - // Method to dispatch type serialization + Optional<uint32_t> findFunctionID(Operation *op) const { + auto it = funcIDMap.find(op); + return it != funcIDMap.end() ? it->second : Optional<uint32_t>(); + } + + /// Processes a SPIR-V function op. + LogicalResult processFuncOp(FuncOp op); + + //===--------------------------------------------------------------------===// + // Types + //===--------------------------------------------------------------------===// + + Optional<uint32_t> findTypeID(Type type) const { + auto it = typeIDMap.find(type); + return it != typeIDMap.end() ? it->second : Optional<uint32_t>(); + } + + Type voidType() { return mlir::NoneType::get(module.getContext()); } + + bool isVoidType(Type type) const { return type.isa<NoneType>(); } + + /// Main dispatch method for serializing a type. The result <id> of the + /// serialized type will be returned as `typeID`. LogicalResult processType(Location loc, Type type, uint32_t &typeID); - // Methods to serialize individual types + /// Method for preparing basic SPIR-V type serialization. Returns the type's + /// opcode and operands for the instruction via `typeEnum` and `operands`. LogicalResult processBasicType(Location loc, Type type, spirv::Opcode &typeEnum, SmallVectorImpl<uint32_t> &operands); + LogicalResult processFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, SmallVectorImpl<uint32_t> &operands); - // Main method to dispatch operation serialization + //===--------------------------------------------------------------------===// + // Operations + //===--------------------------------------------------------------------===// + + Optional<uint32_t> findValueID(Value *val) const { + auto it = valueIDMap.find(val); + return it != valueIDMap.end() ? it->second : Optional<uint32_t>(); + } + + /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); /// Method to dispatch to the serialization function for an operation in - /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. This - /// is auto-generated from ODS. Dispatch is handled for all operations in - /// SPIR-V dialect that have hasOpcode == 1 + /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. + /// This is auto-generated from ODS. Dispatch is handled for all operations + /// in SPIR-V dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenSerialization(Operation *op); /// Method to serialize an operation in the SPIR-V dialect that is a mirror of /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == - /// 1 and autogenSerialization == 1 in ODS + /// 1 and autogenSerialization == 1 in ODS. template <typename OpTy> LogicalResult processOp(OpTy op) { return processOpImpl(op); } + template <typename OpTy> LogicalResult processOpImpl(OpTy op) { return op.emitError("unsupported op serialization"); } - // Methods to serialize individual operation types - LogicalResult processFuncOp(FuncOp op); - - uint32_t getNextID() { return nextID++; } - - Optional<uint32_t> findTypeID(Type type) const { - auto it = typeIDMap.find(type); - return (it != typeIDMap.end() ? it->second : Optional<uint32_t>(None)); - } - - Optional<uint32_t> findValueID(Value *val) const { - auto it = valueIDMap.find(val); - return (it != valueIDMap.end() ? it->second : Optional<uint32_t>(None)); - } - - Optional<uint32_t> findFunctionID(Operation *op) const { - auto it = funcIDMap.find(op); - return (it != funcIDMap.end() ? it->second : Optional<uint32_t>(None)); - } - - Type voidType() { return mlir::NoneType::get(module.getContext()); } - bool isVoidType(Type type) const { return type.isa<NoneType>(); } - private: /// The SPIR-V module to be serialized. spirv::ModuleOp module; @@ -182,8 +199,6 @@ LogicalResult Serializer::serialize() { } void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { - // The number of words in the SPIR-V module header - auto moduleSize = header.size() + capabilities.size() + extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + @@ -205,6 +220,9 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } +//===----------------------------------------------------------------------===// +// Module structure +//===----------------------------------------------------------------------===// LogicalResult Serializer::processHeader() { // The serializer tool ID registered to the Khronos Group @@ -245,6 +263,66 @@ LogicalResult Serializer::processMemoryModel() { return success(); } +LogicalResult Serializer::processFuncOp(FuncOp op) { + uint32_t fnTypeID = 0; + // Generate type of the function + processType(op.getLoc(), op.getType(), fnTypeID); + + // Add the function definition + SmallVector<uint32_t, 4> operands; + uint32_t resTypeID = 0; + auto resultTypes = op.getType().getResults(); + if (resultTypes.size() > 1) { + return emitError(op.getLoc(), + "cannot serialize function with multiple return types"); + } + if (failed(processType(op.getLoc(), + (resultTypes.empty() ? voidType() : resultTypes[0]), + resTypeID))) { + return failure(); + } + operands.push_back(resTypeID); + auto funcID = getNextID(); + funcIDMap[op.getOperation()] = funcID; + operands.push_back(funcID); + // TODO : Support other function control options + operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None)); + operands.push_back(fnTypeID); + buildInstruction(spirv::Opcode::OpFunction, operands, functions); + + // Declare the parameters + for (auto arg : op.getArguments()) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) { + return failure(); + } + auto argValueID = getNextID(); + valueIDMap[arg] = argValueID; + buildInstruction(spirv::Opcode::OpFunctionParameter, + {argTypeID, argValueID}, functions); + } + + // Process the body + if (op.isExternal()) { + return emitError(op.getLoc(), "external function is unhandled"); + } + + for (auto &b : op) + for (auto &op : b) + if (failed(processOperation(&op))) { + return failure(); + } + + // Insert Function End + buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Type +//===----------------------------------------------------------------------===// + LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { auto id = findTypeID(type); @@ -316,6 +394,10 @@ Serializer::processFunctionType(Location loc, FunctionType type, return success(); } +//===----------------------------------------------------------------------===// +// Operation +//===----------------------------------------------------------------------===// + LogicalResult Serializer::processOperation(Operation *op) { // First dispatch the methods that do not directly mirror an operation from // the SPIR-V spec @@ -327,67 +409,8 @@ LogicalResult Serializer::processOperation(Operation *op) { return dispatchToAutogenSerialization(op); } -LogicalResult Serializer::processFuncOp(FuncOp op) { - uint32_t fnTypeID = 0; - // Generate type of the function - processType(op.getLoc(), op.getType(), fnTypeID); - - /// Add the function definition - SmallVector<uint32_t, 4> operands; - uint32_t resTypeID = 0; - auto resultTypes = op.getType().getResults(); - if (resultTypes.size() > 1) { - return emitError(op.getLoc(), - "cannot serialize function with multiple return types"); - } - if (failed(processType(op.getLoc(), - (resultTypes.empty() ? voidType() : resultTypes[0]), - resTypeID))) { - return failure(); - } - operands.push_back(resTypeID); - auto funcID = getNextID(); - funcIDMap[op.getOperation()] = funcID; - operands.push_back(funcID); - /// TODO : Support other function control options - operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None)); - operands.push_back(fnTypeID); - buildInstruction(spirv::Opcode::OpFunction, operands, functions); - - // Declare the parameters - for (auto arg : op.getArguments()) { - uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) { - return failure(); - } - auto argValueID = getNextID(); - valueIDMap[arg] = argValueID; - buildInstruction(spirv::Opcode::OpFunctionParameter, - {argTypeID, argValueID}, functions); - } - - // Process the body - if (!op.empty()) { - for (auto &b : op) { - for (auto &op : b) { - if (failed(processOperation(&op))) { - return failure(); - } - } - } - } - - // Insert Function End - buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions); - - // If the function body is empty return an error - // TODO : Handle external functions - if (op.empty()) { - return emitError(op.getLoc(), "external function is unhandled"); - } - return success(); -} - +// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and +// various processOpImpl specializations. #define GET_SERIALIZATION_FNS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" |