diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | 71 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 41 |
2 files changed, 112 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index a3d71eda5d9..412487d16f4 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" @@ -84,6 +85,9 @@ private: /// Method to process an OpDecorate instruction. LogicalResult processDecoration(ArrayRef<uint32_t> words); + // Method to process an OpMemberDecorate instruction. + LogicalResult processMemberDecoration(ArrayRef<uint32_t> words); + /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. /// This method processes each instruction inside the function and dispatches @@ -122,6 +126,8 @@ private: LogicalResult processFunctionType(ArrayRef<uint32_t> operands); + LogicalResult processStructType(ArrayRef<uint32_t> operands); + //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// @@ -232,6 +238,9 @@ private: // Result <id> to type decorations. DenseMap<uint32_t, uint32_t> typeDecorations; + // Result <id> to member decorations. + DenseMap<uint32_t, DenseMap<uint32_t, uint32_t>> memberDecorationMap; + // List of instructions that are processed in a defered fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function @@ -368,6 +377,23 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) { return success(); } +LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { + // The binary layout of OpMemberDecorate is different comparing to OpDecorate + if (words.size() != 4) { + return emitError(unknownLoc, "OpMemberDecorate must have 4 operands"); + } + + switch (static_cast<spirv::Decoration>(words[2])) { + case spirv::Decoration::Offset: + memberDecorationMap[words[0]][words[1]] = words[3]; + break; + default: + return emitError(unknownLoc, "unhandled OpMemberDecoration case: ") + << words[2]; + } + return success(); +} + LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { // Get the result type if (operands.size() != 4) { @@ -653,6 +679,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, return processArrayType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); + case spirv::Opcode::OpTypeStruct: + return processStructType(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } @@ -722,6 +750,46 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { return success(); } +LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) { + // TODO(ravishankarm) : Regarding to the spec spv.struct must support zero + // amount of members. + if (operands.size() < 2) { + return emitError(unknownLoc, "OpTypeStruct must have at least 2 operand"); + } + + SmallVector<Type, 0> memberTypes; + for (auto op : llvm::drop_begin(operands, 1)) { + Type memberType = getType(op); + if (!memberType) { + return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") + << op; + } + memberTypes.push_back(memberType); + } + + SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo; + // Check for layoutinfo + auto memberDecorationIt = memberDecorationMap.find(operands[0]); + if (memberDecorationIt != memberDecorationMap.end()) { + // Each member must have an offset + const auto &offsetDecorationMap = memberDecorationIt->second; + auto offsetDecorationMapEnd = offsetDecorationMap.end(); + for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { + // Check that specific member has an offset + auto offsetIt = offsetDecorationMap.find(memberIndex); + if (offsetIt == offsetDecorationMapEnd) { + return emitError(unknownLoc, "OpTypeStruct with <id> ") + << operands[0] << " must have an offset for " << memberIndex + << "-th member"; + } + layoutInfo.push_back( + static_cast<spirv::StructType::LayoutInfo>(offsetIt->second)); + } + } + typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo); + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// @@ -993,6 +1061,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, case spirv::Opcode::OpTypeVector: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: + case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: return processType(opcode, operands); case spirv::Opcode::OpConstant: @@ -1015,6 +1084,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processConstantNull(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); + case spirv::Opcode::OpMemberDecorate: + return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); default: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 575d995bf45..bc0b706092c 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/raw_ostream.h" @@ -148,6 +149,11 @@ private: return emitError(loc, "unhandled decoraion for type:") << type; } + /// Process member decoration + LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberNum, + spirv::Decoration decorationType, + uint32_t value); + //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// @@ -411,6 +417,16 @@ LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( } return success(); } + +LogicalResult +Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, + spirv::Decoration decorationType, + uint32_t value) { + SmallVector<uint32_t, 4> args( + {structID, memberIndex, static_cast<uint32_t>(decorationType), value}); + return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, + args); +} } // namespace LogicalResult Serializer::processFuncOp(FuncOp op) { @@ -618,6 +634,31 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, return success(); } + if (auto structType = type.dyn_cast<spirv::StructType>()) { + bool hasLayout = structType.hasLayout(); + for (auto elementIndex : + llvm::seq<uint32_t>(0, structType.getNumElements())) { + uint32_t elementTypeID = 0; + if (failed(processType(loc, structType.getElementType(elementIndex), + elementTypeID))) { + return failure(); + } + operands.push_back(elementTypeID); + if (hasLayout) { + // Decorate each struct member with an offset + if (failed(processMemberDecoration( + resultID, elementIndex, spirv::Decoration::Offset, + static_cast<uint32_t>(structType.getOffset(elementIndex))))) { + return emitError(loc, "cannot decorate ") + << elementIndex << "-th member of : " << structType + << "with its offset"; + } + } + } + typeEnum = spirv::Opcode::OpTypeStruct; + return success(); + } + // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } |

