diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-09-11 14:02:23 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-09-11 14:02:59 -0700 |
| commit | a84bc68accc5103621df3b1661153c419ecafed7 (patch) | |
| tree | 3ae460e3c73141b047e91cf56d5b5b85784f48ca /mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | |
| parent | e15356f8edab9ce4b4edbaf2e988d6b38adb59cc (diff) | |
| download | bcm5719-llvm-a84bc68accc5103621df3b1661153c419ecafed7.tar.gz bcm5719-llvm-a84bc68accc5103621df3b1661153c419ecafed7.zip | |
[spirv] Add support for spv.loop (de)serialization
This CL adds support for serializing and deserializing spv.loop ops.
This adds support for spv.Branch and spv.BranchConditional op
(de)serialization, too, because they are needed for spv.loop.
PiperOrigin-RevId: 268536962
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 255 |
1 files changed, 231 insertions, 24 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 43a1d08cf6c..e05c0d4f8e6 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -29,6 +29,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/raw_ostream.h" @@ -248,6 +249,28 @@ private: bool isSpec = false); //===--------------------------------------------------------------------===// + // Control flow + //===--------------------------------------------------------------------===// + + uint32_t findBlockID(Block *block) const { return blockIDMap.lookup(block); } + + uint32_t assignBlockID(Block *block); + + // Processes the given `block` and emits SPIR-V instructions for all ops + // inside. `actionBeforeTerminator` is a callback that will be invoked before + // handling the terminator op. It can be used to inject the Op*Merge + // instruction if this is a SPIR-V selection/loop header block. + LogicalResult + processBlock(Block *block, + llvm::function_ref<void()> actionBeforeTerminator = nullptr); + + LogicalResult processLoopOp(spirv::LoopOp loopOp); + + LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); + + LogicalResult processBranchOp(spirv::BranchOp branchOp); + + //===--------------------------------------------------------------------===// // Operations //===--------------------------------------------------------------------===// @@ -313,6 +336,9 @@ private: /// Map from FuncOps name to <id>s. llvm::StringMap<uint32_t> funcIDMap; + /// Map from blocks to their <id>s. + DenseMap<Block *, uint32_t> blockIDMap; + /// Map from results of normal operations to their <id>s. DenseMap<Value *, uint32_t> valueIDMap; }; @@ -503,8 +529,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { uint32_t resTypeID = 0; auto resultTypes = op.getType().getResults(); if (resultTypes.size() > 1) { - return emitError(op.getLoc(), - "cannot serialize function with multiple return types"); + return op.emitError("cannot serialize function with multiple return types"); } if (failed(processType(op.getLoc(), (resultTypes.empty() ? getVoidType() : resultTypes[0]), @@ -539,20 +564,15 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { // Process the body. if (op.isExternal()) { - return emitError(op.getLoc(), "external function is unhandled"); + return op.emitError("external function is unhandled"); } - for (auto &b : op) { - // TODO(antiagainst): support basic blocks and control flow properly. - encodeInstructionInto(functions, spirv::Opcode::OpLabel, {getNextID()}); - for (auto &op : b) { - if (failed(processOperation(&op))) { - return failure(); - } - } + for (auto &block : op) { + if (failed(processBlock(&block))) + return failure(); } - // Insert Function End. + // Insert OpFunctionEnd. return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); } @@ -1137,6 +1157,181 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, } //===----------------------------------------------------------------------===// +// Control flow +//===----------------------------------------------------------------------===// + +uint32_t Serializer::assignBlockID(Block *block) { + assert(blockIDMap.lookup(block) == 0 && "block already has <id>"); + return blockIDMap[block] = getNextID(); +} + +LogicalResult +Serializer::processBlock(Block *block, + llvm::function_ref<void()> actionBeforeTerminator) { + auto blockID = findBlockID(block); + if (blockID == 0) { + blockID = assignBlockID(block); + } + + // Emit OpLabel for this block. + encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID}); + + // Process each op in this block except the terminator. + for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { + if (failed(processOperation(&op))) + return failure(); + } + + // Process the terminator. + if (actionBeforeTerminator) + actionBeforeTerminator(); + if (failed(processOperation(&block->back()))) + return failure(); + + return success(); +} + +namespace { +/// A pre-order depth-first vistor for processing basic blocks in a spv.loop op. +/// +/// This visitor is special tailored for spv.loop block serialization to satisfy +/// SPIR-V validation rules. It should not be used as a general depth-first +/// block visitor. +class LoopBlockVisitor { +public: + using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>; + + /// Visits the basic blocks starting from the given `headerBlock`'s successors + /// in pre-order depth-first manner and calls `blockHandler` on each block. + /// Skips handling the `headerBlock` and blocks in the `skipBlocks` list. + static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler, + ArrayRef<Block *> skipBlocks) { + return LoopBlockVisitor(blockHandler, skipBlocks) + .visitHeaderBlock(headerBlock); + } + +private: + LoopBlockVisitor(BlockHandlerType blockHandler, ArrayRef<Block *> skipBlocks) + : blockHandler(blockHandler), + doneBlocks(skipBlocks.begin(), skipBlocks.end()) {} + + LogicalResult visitHeaderBlock(Block *header) { + // Skip processing the header block. + doneBlocks.insert(header); + + for (auto *successor : header->getSuccessors()) { + if (failed(visitNormalBlock(successor))) + return failure(); + } + + return success(); + } + + LogicalResult visitNormalBlock(Block *block) { + if (doneBlocks.count(block)) + return success(); + + if (failed(blockHandler(block))) + return failure(); + doneBlocks.insert(block); + + for (auto *successor : block->getSuccessors()) { + if (failed(visitNormalBlock(successor))) + return failure(); + } + + return success(); + } + + BlockHandlerType blockHandler; + SmallPtrSet<Block *, 4> doneBlocks; +}; +} // namespace + +LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { + // SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order + // of blocks in a function must satisfy the rule that blocks appear before all + // blocks they dominate." This can be achieved by a pre-order CFG traversal + // algorithm. To make the serialization output more logical and readable to + // human, we perform depth-first CFG traversal and delay the serialization of + // the continue block and the merge block until after all other blocks have + // been processed. + + // Assign <id>s to all blocks so that branchs inside the LoopOp can resolve + // properly. We don't need to assign for the entry block, which is just for + // satisfying MLIR region's structural requirement. + auto &body = loopOp.body(); + for (Block &block : + llvm::make_range(std::next(body.begin(), 1), body.end())) { + assignBlockID(&block); + } + auto *headerBlock = loopOp.getHeaderBlock(); + auto *continueBlock = loopOp.getContinueBlock(); + auto *mergeBlock = loopOp.getMergeBlock(); + auto headerID = findBlockID(headerBlock); + auto continueID = findBlockID(continueBlock); + auto mergeID = findBlockID(mergeBlock); + + // This LoopOp is in some MLIR block with preceding and following ops. In the + // binary format, it should reside in separate SPIR-V blocks from its + // preceding and following ops. So we need to emit unconditional branches to + // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow + // afterwards. + + encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID}); + + // Emit the loop header block, which dominates all other blocks, first. We + // need to emit an OpLoopMerge instruction before the loop header block's + // terminator. + auto emitLoopMerge = [&]() { + // TODO(antiagainst): properly support loop control here + encodeInstructionInto( + functions, spirv::Opcode::OpLoopMerge, + {mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)}); + }; + if (failed(processBlock(headerBlock, emitLoopMerge))) + return failure(); + + // Process all blocks with a depth-first visitor starting from the header + // block. The loop header block, loop continue block, and loop merge block are + // skipped by this visitor and handled later in this function. + auto handleBlock = [&](Block *block) { return processBlock(block); }; + if (failed(LoopBlockVisitor::visit(headerBlock, handleBlock, + {continueBlock, mergeBlock}))) + return failure(); + + // We have handled all other blocks. Now get to the loop continue block. + if (failed(processBlock(continueBlock))) + return failure(); + + // There is nothing to do for the merge block in the loop, which just contains + // a spv._merge op, itself. But we need to have an OpLabel instruction to + // start a new SPIR-V block for ops following this LoopOp. + return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID}); +} + +LogicalResult Serializer::processBranchConditionalOp( + spirv::BranchConditionalOp condBranchOp) { + auto conditionID = findValueID(condBranchOp.condition()); + auto trueLabelID = findBlockID(condBranchOp.getTrueBlock()); + auto falseLabelID = findBlockID(condBranchOp.getFalseBlock()); + SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; + + if (auto weights = condBranchOp.branch_weights()) { + for (auto val : weights->getValue()) + arguments.push_back(val.cast<IntegerAttr>().getInt()); + } + + return encodeInstructionInto(functions, spirv::Opcode::OpBranchConditional, + arguments); +} + +LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { + return encodeInstructionInto(functions, spirv::Opcode::OpBranch, + {findBlockID(branchOp.getTarget())}); +} + +//===----------------------------------------------------------------------===// // Operation //===----------------------------------------------------------------------===// @@ -1165,29 +1360,41 @@ Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { } LogicalResult Serializer::processOperation(Operation *op) { - // First dispatch the methods that do not directly mirror an operation from - // the SPIR-V spec - if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) { - return processConstantOp(constOp); + // First dispatch the ops that do not directly mirror an instruction from + // the SPIR-V spec. + if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) { + return processAddressOfOp(addressOfOp); } - if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) { - return processSpecConstantOp(specConstOp); + if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { + return processBranchOp(branchOp); } - if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) { - return processReferenceOfOp(refOpOp); + if (auto condBranchOp = dyn_cast<spirv::BranchConditionalOp>(op)) { + return processBranchConditionalOp(condBranchOp); + } + if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) { + return processConstantOp(constOp); } if (auto fnOp = dyn_cast<FuncOp>(op)) { return processFuncOp(fnOp); } + if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { + return processGlobalVariableOp(varOp); + } + if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) { + return processLoopOp(loopOp); + } if (isa<spirv::ModuleEndOp>(op)) { return success(); } - if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { - return processGlobalVariableOp(varOp); + if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) { + return processReferenceOfOp(refOpOp); } - if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) { - return processAddressOfOp(addressOfOp); + if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) { + return processSpecConstantOp(specConstOp); } + + // Then handle all the ops that directly mirror SPIR-V instructions with + // auto-generated methods. return dispatchToAutogenSerialization(op); } |

