summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td4
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp354
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp209
3 files changed, 312 insertions, 255 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 83f474cd780..66544f0730a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -539,8 +539,8 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
// Controls whether the (de)serialization method is generated automatically or
// not. This results in generation of the following methods:
//
- // template<typename OpTy> Serialization::processOp(OpTy op)
- // template<typename OpTy> Deserialization::processOp(ArrayRef<uint32_t>)
+ // template<typename OpTy> Serializer::processOp(OpTy op)
+ // template<typename OpTy> Deserializer::processOp(ArrayRef<uint32_t>)
//
// If the auto generation is disabled (set to 0), then manual implementation
// of a specialization of these methods is required.
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 676ef8ec871..1d3a934f6f8 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -63,35 +63,63 @@ public:
Optional<spirv::ModuleOp> collect();
private:
- /// Get type for a given result <id>
- Type getType(uint32_t id) { return typeMap.lookup(id); }
+ //===--------------------------------------------------------------------===//
+ // Module structure
+ //===--------------------------------------------------------------------===//
- /// Get Value associated with a result <id>
- Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+ /// Initializes the `module` ModuleOp in this deserializer instance.
+ spirv::ModuleOp createModuleOp();
+
+ /// Processes SPIR-V module header in `binary`.
+ LogicalResult processHeader();
+
+ /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
+ LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
+
+ /// Processes the SPIR-V function at the current `offset` into `binary`.
+ /// The operands to the OpFunction instruction is passed in as ``operands`.
+ /// This method processes each instruction inside the function and dispatches
+ /// them to their handler method accordingly.
+ LogicalResult processFunction(ArrayRef<uint32_t> operands);
- // Check if a type is void
+ //===--------------------------------------------------------------------===//
+ // Type
+ //===--------------------------------------------------------------------===//
+
+ /// Gets type for a given result <id>.
+ Type getType(uint32_t id) { return typeMap.lookup(id); }
+
+ /// Returns true if the given `type` is for SPIR-V void type.
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
- /// Processes SPIR-V module header.
- LogicalResult processHeader();
+ /// Processes a SPIR-V type instruction with given `opcode` and `operands` and
+ /// registers the type into `module`.
+ LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
- /// Deserialize a single instruction. The |opcode| and |operands| are returned
- /// after deserialization to the caller.
- LogicalResult deserializeInstruction(spirv::Opcode &opcode,
- ArrayRef<uint32_t> &operands);
+ LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
+
+ //===--------------------------------------------------------------------===//
+ // Instruction
+ //===--------------------------------------------------------------------===//
+
+ /// Get the Value associated with a result <id>.
+ Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+
+ /// Slices the first instruction out of `binary` and returns its opcode and
+ /// operands via `opcode` and `operands` respectively.
+ LogicalResult sliceInstruction(spirv::Opcode &opcode,
+ ArrayRef<uint32_t> &operands);
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
+ /// This method is the main entrance for handling SPIR-V instruction; it
+ /// checks the instruction opcode and dispatches to the corresponding handler.
LogicalResult processInstruction(spirv::Opcode opcode,
ArrayRef<uint32_t> operands);
- /// Processes a SPIR-V type instruction with given 'opcode' and 'operands'
- LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
- LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
-
/// Method to dispatch to the specialized deserialization function for an
- /// operation in SPIR-V dialect that is a mirror of an operation in the SPIR-V
- /// spec. This is auto-generated from ODS. Dispatch is handled for all
- /// operations in SPIR-V dialect that have hasOpcode == 1
+ /// operation in SPIR-V dialect that is a mirror of an instruction in the
+ /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
+ /// all operations in SPIR-V dialect that have hasOpcode == 1.
LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
ArrayRef<uint32_t> words);
@@ -101,20 +129,13 @@ private:
template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
return processOpImpl<OpTy>(words);
}
+
template <typename OpTy>
LogicalResult processOpImpl(ArrayRef<uint32_t> words) {
- return emitError(unknownLoc, "unsupported deserialization for op '")
- << OpTy::getOperationName() << "')";
+ return emitError(unknownLoc, "unsupported deserialization for ")
+ << OpTy::getOperationName() << " op";
}
- /// Process function objects in binary
- LogicalResult processFunction(ArrayRef<uint32_t> operands);
-
- LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
-
- /// Initializes the `module` ModuleOp in this deserializer instance.
- spirv::ModuleOp createModuleOp();
-
private:
/// The SPIR-V binary module.
ArrayRef<uint32_t> binary;
@@ -155,7 +176,7 @@ LogicalResult Deserializer::deserialize() {
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
- while (succeeded(deserializeInstruction(opcode, operands))) {
+ while (succeeded(sliceInstruction(opcode, operands))) {
if (failed(processInstruction(opcode, operands)))
return failure();
}
@@ -165,6 +186,20 @@ LogicalResult Deserializer::deserialize() {
Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
+//===----------------------------------------------------------------------===//
+// Module structure
+//===----------------------------------------------------------------------===//
+
+spirv::ModuleOp Deserializer::createModuleOp() {
+ Builder builder(context);
+ OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
+ // TODO(antiagainst): use target environment to select the version
+ state.addAttribute("major_version", builder.getI32IntegerAttr(1));
+ state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
+ spirv::ModuleOp::build(&builder, &state);
+ return llvm::cast<spirv::ModuleOp>(Operation::create(state));
+}
+
LogicalResult Deserializer::processHeader() {
if (binary.size() < spirv::kHeaderWordCount)
return emitError(unknownLoc,
@@ -178,113 +213,16 @@ LogicalResult Deserializer::processHeader() {
return success();
}
-LogicalResult
-Deserializer::deserializeInstruction(spirv::Opcode &opcode,
- ArrayRef<uint32_t> &operands) {
- auto binarySize = binary.size();
- if (curOffset >= binarySize) {
- return failure();
- }
- // For each instruction, get its word count from the first word to slice it
- // from the stream properly, and then dispatch to the instruction handler.
-
- uint32_t wordCount = binary[curOffset] >> 16;
- opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
-
- if (wordCount == 0)
- return emitError(unknownLoc, "word count cannot be zero");
-
- uint32_t nextOffset = curOffset + wordCount;
- if (nextOffset > binarySize)
- return emitError(unknownLoc, "insufficient words for the last instruction");
-
- operands = binary.slice(curOffset + 1, wordCount - 1);
- curOffset = nextOffset;
- return success();
-}
+LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2)
+ return emitError(unknownLoc, "OpMemoryModel must have two operands");
-LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
- assert(!operands.empty() && "No operands for processing function type");
- if (operands.size() == 1) {
- return emitError(unknownLoc, "missing return type for OpTypeFunction");
- }
- auto returnType = getType(operands[1]);
- if (!returnType) {
- return emitError(unknownLoc, "unknown return type in OpTypeFunction");
- }
- SmallVector<Type, 1> argTypes;
- for (size_t i = 2, e = operands.size(); i < e; ++i) {
- auto ty = getType(operands[i]);
- if (!ty) {
- return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
- }
- argTypes.push_back(ty);
- }
- ArrayRef<Type> returnTypes;
- if (!isVoidType(returnType)) {
- returnTypes = llvm::makeArrayRef(returnType);
- }
- typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
- return success();
-}
+ module->setAttr(
+ "addressing_model",
+ opBuilder.getI32IntegerAttr(bitwiseCast<int32_t>(operands.front())));
+ module->setAttr("memory_model", opBuilder.getI32IntegerAttr(
+ bitwiseCast<int32_t>(operands.back())));
-LogicalResult Deserializer::processType(spirv::Opcode opcode,
- ArrayRef<uint32_t> operands) {
- if (operands.empty()) {
- return emitError(unknownLoc, "type instruction with opcode ")
- << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
- }
- /// TODO: Types might be forward declared in some instructions and need to be
- /// handled appropriately.
- if (typeMap.count(operands[0])) {
- return emitError(unknownLoc, "duplicate definition for result <id> ")
- << operands[0];
- }
- switch (opcode) {
- case spirv::Opcode::OpTypeVoid:
- if (operands.size() != 1) {
- return emitError(unknownLoc, "OpTypeVoid must have no parameters");
- }
- typeMap[operands[0]] = NoneType::get(context);
- break;
- case spirv::Opcode::OpTypeFloat: {
- if (operands.size() != 2) {
- return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
- }
- Type floatTy;
- switch (operands[1]) {
- case 16:
- floatTy = opBuilder.getF16Type();
- break;
- case 32:
- floatTy = opBuilder.getF32Type();
- break;
- case 64:
- floatTy = opBuilder.getF64Type();
- break;
- default:
- return emitError(unknownLoc, "unsupported bitwdith ")
- << operands[1] << " with OpTypeFloat";
- }
- typeMap[operands[0]] = floatTy;
- } break;
- case spirv::Opcode::OpTypePointer: {
- if (operands.size() != 3) {
- return emitError(unknownLoc, "OpTypePointer must have two parameters");
- }
- auto pointeeType = getType(operands[2]);
- if (!pointeeType) {
- return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> : ")
- << operands[2];
- }
- auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
- typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
- } break;
- case spirv::Opcode::OpTypeFunction:
- return processFunctionType(operands);
- default:
- return emitError(unknownLoc, "unhandled type instruction");
- }
return success();
}
@@ -335,7 +273,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
auto argType = functionType.getInput(i);
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
- if (failed(deserializeInstruction(opcode, operands))) {
+ if (failed(sliceInstruction(opcode, operands))) {
return failure();
}
if (opcode != spirv::Opcode::OpFunctionParameter) {
@@ -372,7 +310,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
spirv::Opcode opcode;
ArrayRef<uint32_t> instOperands;
- while (succeeded(deserializeInstruction(opcode, instOperands)) &&
+ while (succeeded(sliceInstruction(opcode, instOperands)) &&
opcode != spirv::Opcode::OpFunctionEnd) {
if (failed(processInstruction(opcode, instOperands))) {
return failure();
@@ -388,8 +326,122 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return success();
}
-#define GET_DESERIALIZATION_FNS
-#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::processType(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands) {
+ if (operands.empty()) {
+ return emitError(unknownLoc, "type instruction with opcode ")
+ << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
+ }
+ /// TODO: Types might be forward declared in some instructions and need to be
+ /// handled appropriately.
+ if (typeMap.count(operands[0])) {
+ return emitError(unknownLoc, "duplicate definition for result <id> ")
+ << operands[0];
+ }
+ switch (opcode) {
+ case spirv::Opcode::OpTypeVoid:
+ if (operands.size() != 1) {
+ return emitError(unknownLoc, "OpTypeVoid must have no parameters");
+ }
+ typeMap[operands[0]] = NoneType::get(context);
+ break;
+ case spirv::Opcode::OpTypeFloat: {
+ if (operands.size() != 2) {
+ return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+ }
+ Type floatTy;
+ switch (operands[1]) {
+ case 16:
+ floatTy = opBuilder.getF16Type();
+ break;
+ case 32:
+ floatTy = opBuilder.getF32Type();
+ break;
+ case 64:
+ floatTy = opBuilder.getF64Type();
+ break;
+ default:
+ return emitError(unknownLoc, "unsupported bitwdith ")
+ << operands[1] << " with OpTypeFloat";
+ }
+ typeMap[operands[0]] = floatTy;
+ } break;
+ case spirv::Opcode::OpTypePointer: {
+ if (operands.size() != 3) {
+ return emitError(unknownLoc, "OpTypePointer must have two parameters");
+ }
+ auto pointeeType = getType(operands[2]);
+ if (!pointeeType) {
+ return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> : ")
+ << operands[2];
+ }
+ auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
+ typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
+ } break;
+ case spirv::Opcode::OpTypeFunction:
+ return processFunctionType(operands);
+ default:
+ return emitError(unknownLoc, "unhandled type instruction");
+ }
+ return success();
+}
+
+LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
+ assert(!operands.empty() && "No operands for processing function type");
+ if (operands.size() == 1) {
+ return emitError(unknownLoc, "missing return type for OpTypeFunction");
+ }
+ auto returnType = getType(operands[1]);
+ if (!returnType) {
+ return emitError(unknownLoc, "unknown return type in OpTypeFunction");
+ }
+ SmallVector<Type, 1> argTypes;
+ for (size_t i = 2, e = operands.size(); i < e; ++i) {
+ auto ty = getType(operands[i]);
+ if (!ty) {
+ return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
+ }
+ argTypes.push_back(ty);
+ }
+ ArrayRef<Type> returnTypes;
+ if (!isVoidType(returnType)) {
+ returnTypes = llvm::makeArrayRef(returnType);
+ }
+ typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Instruction
+//===----------------------------------------------------------------------===//
+
+LogicalResult Deserializer::sliceInstruction(spirv::Opcode &opcode,
+ ArrayRef<uint32_t> &operands) {
+ auto binarySize = binary.size();
+ if (curOffset >= binarySize) {
+ return failure();
+ }
+ // For each instruction, get its word count from the first word to slice it
+ // from the stream properly, and then dispatch to the instruction handler.
+
+ uint32_t wordCount = binary[curOffset] >> 16;
+ opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
+
+ if (wordCount == 0)
+ return emitError(unknownLoc, "word count cannot be zero");
+
+ uint32_t nextOffset = curOffset + wordCount;
+ if (nextOffset > binarySize)
+ return emitError(unknownLoc, "insufficient words for the last instruction");
+
+ operands = binary.slice(curOffset + 1, wordCount - 1);
+ curOffset = nextOffset;
+ return success();
+}
LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
ArrayRef<uint32_t> operands) {
@@ -410,28 +462,10 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return dispatchToAutogenDeserialization(opcode, operands);
}
-LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
- if (operands.size() != 2)
- return emitError(unknownLoc, "OpMemoryModel must have two operands");
-
- module->setAttr(
- "addressing_model",
- opBuilder.getI32IntegerAttr(bitwiseCast<int32_t>(operands.front())));
- module->setAttr("memory_model", opBuilder.getI32IntegerAttr(
- bitwiseCast<int32_t>(operands.back())));
-
- return success();
-}
-
-spirv::ModuleOp Deserializer::createModuleOp() {
- Builder builder(context);
- OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
- // TODO(antiagainst): use target environment to select the version
- state.addAttribute("major_version", builder.getI32IntegerAttr(1));
- state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
- spirv::ModuleOp::build(&builder, &state);
- return llvm::cast<spirv::ModuleOp>(Operation::create(state));
-}
+// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
+// various processOpImpl specializations.
+#define GET_DESERIALIZATION_FNS
+#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
MLIRContext *context) {
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index b1c5936ae1a..58c89aa219a 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -73,64 +73,81 @@ public:
void collect(SmallVectorImpl<uint32_t> &binary);
private:
+ uint32_t getNextID() { return nextID++; }
+
+ //===--------------------------------------------------------------------===//
+ // Module structure
+ //===--------------------------------------------------------------------===//
+
/// Creates SPIR-V module header in the given `header`.
LogicalResult processHeader();
LogicalResult processMemoryModel();
- // Method to dispatch type serialization
+ Optional<uint32_t> findFunctionID(Operation *op) const {
+ auto it = funcIDMap.find(op);
+ return it != funcIDMap.end() ? it->second : Optional<uint32_t>();
+ }
+
+ /// Processes a SPIR-V function op.
+ LogicalResult processFuncOp(FuncOp op);
+
+ //===--------------------------------------------------------------------===//
+ // Types
+ //===--------------------------------------------------------------------===//
+
+ Optional<uint32_t> findTypeID(Type type) const {
+ auto it = typeIDMap.find(type);
+ return it != typeIDMap.end() ? it->second : Optional<uint32_t>();
+ }
+
+ Type voidType() { return mlir::NoneType::get(module.getContext()); }
+
+ bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+
+ /// Main dispatch method for serializing a type. The result <id> of the
+ /// serialized type will be returned as `typeID`.
LogicalResult processType(Location loc, Type type, uint32_t &typeID);
- // Methods to serialize individual types
+ /// Method for preparing basic SPIR-V type serialization. Returns the type's
+ /// opcode and operands for the instruction via `typeEnum` and `operands`.
LogicalResult processBasicType(Location loc, Type type,
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands);
+
LogicalResult processFunctionType(Location loc, FunctionType type,
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands);
- // Main method to dispatch operation serialization
+ //===--------------------------------------------------------------------===//
+ // Operations
+ //===--------------------------------------------------------------------===//
+
+ Optional<uint32_t> findValueID(Value *val) const {
+ auto it = valueIDMap.find(val);
+ return it != valueIDMap.end() ? it->second : Optional<uint32_t>();
+ }
+
+ /// Main dispatch method for serializing an operation.
LogicalResult processOperation(Operation *op);
/// Method to dispatch to the serialization function for an operation in
- /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. This
- /// is auto-generated from ODS. Dispatch is handled for all operations in
- /// SPIR-V dialect that have hasOpcode == 1
+ /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
+ /// This is auto-generated from ODS. Dispatch is handled for all operations
+ /// in SPIR-V dialect that have hasOpcode == 1.
LogicalResult dispatchToAutogenSerialization(Operation *op);
/// Method to serialize an operation in the SPIR-V dialect that is a mirror of
/// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
- /// 1 and autogenSerialization == 1 in ODS
+ /// 1 and autogenSerialization == 1 in ODS.
template <typename OpTy> LogicalResult processOp(OpTy op) {
return processOpImpl(op);
}
+
template <typename OpTy> LogicalResult processOpImpl(OpTy op) {
return op.emitError("unsupported op serialization");
}
- // Methods to serialize individual operation types
- LogicalResult processFuncOp(FuncOp op);
-
- uint32_t getNextID() { return nextID++; }
-
- Optional<uint32_t> findTypeID(Type type) const {
- auto it = typeIDMap.find(type);
- return (it != typeIDMap.end() ? it->second : Optional<uint32_t>(None));
- }
-
- Optional<uint32_t> findValueID(Value *val) const {
- auto it = valueIDMap.find(val);
- return (it != valueIDMap.end() ? it->second : Optional<uint32_t>(None));
- }
-
- Optional<uint32_t> findFunctionID(Operation *op) const {
- auto it = funcIDMap.find(op);
- return (it != funcIDMap.end() ? it->second : Optional<uint32_t>(None));
- }
-
- Type voidType() { return mlir::NoneType::get(module.getContext()); }
- bool isVoidType(Type type) const { return type.isa<NoneType>(); }
-
private:
/// The SPIR-V module to be serialized.
spirv::ModuleOp module;
@@ -182,8 +199,6 @@ LogicalResult Serializer::serialize() {
}
void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
- // The number of words in the SPIR-V module header
-
auto moduleSize = header.size() + capabilities.size() + extensions.size() +
extendedSets.size() + memoryModel.size() +
entryPoints.size() + executionModes.size() +
@@ -205,6 +220,9 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
binary.append(functions.begin(), functions.end());
}
+//===----------------------------------------------------------------------===//
+// Module structure
+//===----------------------------------------------------------------------===//
LogicalResult Serializer::processHeader() {
// The serializer tool ID registered to the Khronos Group
@@ -245,6 +263,66 @@ LogicalResult Serializer::processMemoryModel() {
return success();
}
+LogicalResult Serializer::processFuncOp(FuncOp op) {
+ uint32_t fnTypeID = 0;
+ // Generate type of the function
+ processType(op.getLoc(), op.getType(), fnTypeID);
+
+ // Add the function definition
+ SmallVector<uint32_t, 4> operands;
+ uint32_t resTypeID = 0;
+ auto resultTypes = op.getType().getResults();
+ if (resultTypes.size() > 1) {
+ return emitError(op.getLoc(),
+ "cannot serialize function with multiple return types");
+ }
+ if (failed(processType(op.getLoc(),
+ (resultTypes.empty() ? voidType() : resultTypes[0]),
+ resTypeID))) {
+ return failure();
+ }
+ operands.push_back(resTypeID);
+ auto funcID = getNextID();
+ funcIDMap[op.getOperation()] = funcID;
+ operands.push_back(funcID);
+ // TODO : Support other function control options
+ operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
+ operands.push_back(fnTypeID);
+ buildInstruction(spirv::Opcode::OpFunction, operands, functions);
+
+ // Declare the parameters
+ for (auto arg : op.getArguments()) {
+ uint32_t argTypeID = 0;
+ if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
+ return failure();
+ }
+ auto argValueID = getNextID();
+ valueIDMap[arg] = argValueID;
+ buildInstruction(spirv::Opcode::OpFunctionParameter,
+ {argTypeID, argValueID}, functions);
+ }
+
+ // Process the body
+ if (op.isExternal()) {
+ return emitError(op.getLoc(), "external function is unhandled");
+ }
+
+ for (auto &b : op)
+ for (auto &op : b)
+ if (failed(processOperation(&op))) {
+ return failure();
+ }
+
+ // Insert Function End
+ buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
LogicalResult Serializer::processType(Location loc, Type type,
uint32_t &typeID) {
auto id = findTypeID(type);
@@ -316,6 +394,10 @@ Serializer::processFunctionType(Location loc, FunctionType type,
return success();
}
+//===----------------------------------------------------------------------===//
+// Operation
+//===----------------------------------------------------------------------===//
+
LogicalResult Serializer::processOperation(Operation *op) {
// First dispatch the methods that do not directly mirror an operation from
// the SPIR-V spec
@@ -327,67 +409,8 @@ LogicalResult Serializer::processOperation(Operation *op) {
return dispatchToAutogenSerialization(op);
}
-LogicalResult Serializer::processFuncOp(FuncOp op) {
- uint32_t fnTypeID = 0;
- // Generate type of the function
- processType(op.getLoc(), op.getType(), fnTypeID);
-
- /// Add the function definition
- SmallVector<uint32_t, 4> operands;
- uint32_t resTypeID = 0;
- auto resultTypes = op.getType().getResults();
- if (resultTypes.size() > 1) {
- return emitError(op.getLoc(),
- "cannot serialize function with multiple return types");
- }
- if (failed(processType(op.getLoc(),
- (resultTypes.empty() ? voidType() : resultTypes[0]),
- resTypeID))) {
- return failure();
- }
- operands.push_back(resTypeID);
- auto funcID = getNextID();
- funcIDMap[op.getOperation()] = funcID;
- operands.push_back(funcID);
- /// TODO : Support other function control options
- operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
- operands.push_back(fnTypeID);
- buildInstruction(spirv::Opcode::OpFunction, operands, functions);
-
- // Declare the parameters
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- buildInstruction(spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID}, functions);
- }
-
- // Process the body
- if (!op.empty()) {
- for (auto &b : op) {
- for (auto &op : b) {
- if (failed(processOperation(&op))) {
- return failure();
- }
- }
- }
- }
-
- // Insert Function End
- buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions);
-
- // If the function body is empty return an error
- // TODO : Handle external functions
- if (op.empty()) {
- return emitError(op.getLoc(), "external function is unhandled");
- }
- return success();
-}
-
+// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
+// various processOpImpl specializations.
#define GET_SERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
OpenPOWER on IntegriCloud