diff options
| author | Mahesh Ravishankar <ravishankarm@google.com> | 2019-07-20 18:11:39 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-07-20 18:12:05 -0700 |
| commit | 2fb53e65ab461c673762de5e3b649cc63c1f84af (patch) | |
| tree | 8aa4468c403a636eb210698895e0f4ab4d9805c6 /mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | |
| parent | a47704e1e17bed5bacca57dfccea557a9b78c8dc (diff) | |
| download | bcm5719-llvm-2fb53e65ab461c673762de5e3b649cc63c1f84af.tar.gz bcm5719-llvm-2fb53e65ab461c673762de5e3b649cc63c1f84af.zip | |
Add (de)serialization of EntryPointOp and ExecutionModeOp
Since the serialization of EntryPointOp contains the name of the
function as well, the function serialization emits the function name
using OpName instruction, which is used during deserialization to get
the correct function name.
PiperOrigin-RevId: 259158784
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 137 |
1 files changed, 108 insertions, 29 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 36acf13b520..18a2e10bcd4 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -46,6 +46,15 @@ static inline void buildInstruction(spirv::Opcode op, } } +static inline void encodeStringLiteral(StringRef literal, + SmallVectorImpl<uint32_t> &buffer) { + // Encoding is the literal + null termination + auto encodingSize = literal.size() / 4 + 1; + auto bufferStartSize = buffer.size(); + buffer.resize(bufferStartSize + encodingSize, 0); + std::memcpy(buffer.data() + bufferStartSize, literal.data(), literal.size()); +} + namespace { /// A SPIR-V module serializer. @@ -84,9 +93,11 @@ private: LogicalResult processMemoryModel(); - Optional<uint32_t> findFunctionID(Operation *op) const { - auto it = funcIDMap.find(op); - return it != funcIDMap.end() ? it->second : Optional<uint32_t>(); + // It is illegal to use <id> 0 for SSA value in SPIR-V serialization. The + // method uses that to check if the function is defined in the serialized + // binary or not. + uint32_t findFunctionID(StringRef fnName) const { + return funcIDMap.lookup(fnName); } /// Processes a SPIR-V function op. @@ -96,10 +107,10 @@ private: // Types //===--------------------------------------------------------------------===// - Optional<uint32_t> findTypeID(Type type) const { - auto it = typeIDMap.find(type); - return it != typeIDMap.end() ? it->second : Optional<uint32_t>(); - } + // It is illegal to use <id> 0 for SSA value in SPIR-V serialization. The + // method uses that to check if the type is defined in the serialized binary + // or not. + uint32_t findTypeID(Type type) const { return typeIDMap.lookup(type); } Type voidType() { return mlir::NoneType::get(module.getContext()); } @@ -123,10 +134,9 @@ private: // Operations //===--------------------------------------------------------------------===// - Optional<uint32_t> findValueID(Value *val) const { - auto it = valueIDMap.find(val); - return it != valueIDMap.end() ? it->second : Optional<uint32_t>(); - } + // It is illegal to use <id> 0 for SSA value in SPIR-V serialization. The + // method uses that to check if `val` has a corresponding <id> + uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); } /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); @@ -166,17 +176,18 @@ private: SmallVector<uint32_t, 0> entryPoints; SmallVector<uint32_t, 4> executionModes; // TODO(antiagainst): debug instructions + SmallVector<uint32_t, 0> names; SmallVector<uint32_t, 0> decorations; SmallVector<uint32_t, 0> typesGlobalValues; SmallVector<uint32_t, 0> functions; - // Map from type used in SPIR-V module to their IDs + // Map from type used in SPIR-V module to their <id>s DenseMap<Type, uint32_t> typeIDMap; - // Map from FuncOps to IDs - DenseMap<Operation *, uint32_t> funcIDMap; + // Map from FuncOps name to <id>s. + llvm::StringMap<uint32_t> funcIDMap; - // Map from Value to Ids + // Map from Value to Ids. DenseMap<Value *, uint32_t> valueIDMap; }; } // namespace @@ -216,6 +227,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); + binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); @@ -250,7 +262,7 @@ LogicalResult Serializer::processHeader() { header.push_back(spirv::kMagicNumber); header.push_back((kMajorVersion << 16) | (kMinorVersion << 8)); header.push_back(kGeneratorNumber); - header.push_back(nextID); // ID bound + header.push_back(nextID); // <id> bound header.push_back(0); // Schema (reserved word) return success(); } @@ -265,10 +277,10 @@ LogicalResult Serializer::processMemoryModel() { LogicalResult Serializer::processFuncOp(FuncOp op) { uint32_t fnTypeID = 0; - // Generate type of the function + // Generate type of the function. processType(op.getLoc(), op.getType(), fnTypeID); - // Add the function definition + // Add the function definition. SmallVector<uint32_t, 4> operands; uint32_t resTypeID = 0; auto resultTypes = op.getType().getResults(); @@ -283,14 +295,20 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { } operands.push_back(resTypeID); auto funcID = getNextID(); - funcIDMap[op.getOperation()] = funcID; + funcIDMap[op.getName()] = funcID; operands.push_back(funcID); - // TODO : Support other function control options + // TODO : Support other function control options. operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None)); operands.push_back(fnTypeID); buildInstruction(spirv::Opcode::OpFunction, operands, functions); - // Declare the parameters + // Add function name. + SmallVector<uint32_t, 4> nameOperands; + nameOperands.push_back(funcID); + encodeStringLiteral(op.getName(), nameOperands); + buildInstruction(spirv::Opcode::OpName, nameOperands, names); + + // Declare the parameters. for (auto arg : op.getArguments()) { uint32_t argTypeID = 0; if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) { @@ -302,18 +320,20 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { {argTypeID, argValueID}, functions); } - // Process the body + // Process the body. if (op.isExternal()) { return emitError(op.getLoc(), "external function is unhandled"); } - for (auto &b : op) - for (auto &op : b) + for (auto &b : op) { + for (auto &op : b) { if (failed(processOperation(&op))) { return failure(); } + } + } - // Insert Function End + // Insert Function End. buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions); return success(); @@ -325,9 +345,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { - auto id = findTypeID(type); - if (id) { - typeID = id.getValue(); + typeID = findTypeID(type); + if (typeID) { return success(); } typeID = getNextID(); @@ -366,7 +385,7 @@ Serializer::processBasicType(Location loc, Type type, spirv::Opcode &typeEnum, operands.push_back(pointeeTypeID); return success(); } - /// TODO(ravishankarm) : Handle other types + // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization : ") << type; } @@ -410,6 +429,66 @@ LogicalResult Serializer::processOperation(Operation *op) { } namespace { +template <> +LogicalResult +Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { + SmallVector<uint32_t, 4> operands; + // Add the ExectionModel. + operands.push_back(static_cast<uint32_t>(op.execution_model())); + // Add the function <id>. + auto funcID = findFunctionID(op.fn()); + if (!funcID) { + return op.emitError("missing <id> for function ") + << op.fn() + << "; function needs to be defined before spv.EntryPoint is " + "serialized"; + } + operands.push_back(funcID); + // Add the name of the function. + encodeStringLiteral(op.fn(), operands); + + // Add the interface values. + for (auto val : op.interface()) { + auto id = findValueID(val); + if (!id) { + return op.emitError("referencing unintialized variable <id>. " + "spv.EntryPoint is at the end of spv.module. All " + "referenced variables should already be defined"); + } + operands.push_back(id); + } + buildInstruction(spirv::Opcode::OpEntryPoint, operands, entryPoints); + return success(); +} + +template <> +LogicalResult +Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { + SmallVector<uint32_t, 4> operands; + // Add the function <id>. + auto funcID = findFunctionID(op.fn()); + if (!funcID) { + return op.emitError("missing <id> for function ") + << op.fn() + << "; function needs to be serialized before ExecutionModeOp is " + "serialized"; + } + operands.push_back(funcID); + // Add the ExecutionMode. + operands.push_back(static_cast<uint32_t>(op.execution_mode())); + + // Serialize values if any. + auto values = op.values(); + if (values) { + for (auto &intVal : values.getValue()) { + operands.push_back(static_cast<uint32_t>( + intVal.cast<IntegerAttr>().getValue().getZExtValue())); + } + } + buildInstruction(spirv::Opcode::OpExecutionMode, operands, executionModes); + return success(); +} + // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various processOpImpl specializations. #define GET_SERIALIZATION_FNS |

