diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-10-02 11:00:50 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-10-02 11:01:57 -0700 |
| commit | f294e0e513464b97ae1bb2f9532979f8698c441e (patch) | |
| tree | 6f2589c45808b090147787ede8149cc0cfb52086 /mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | |
| parent | 088f4c502f9fcba5f39d255ff6cdcf2ab050b201 (diff) | |
| download | bcm5719-llvm-f294e0e513464b97ae1bb2f9532979f8698c441e.tar.gz bcm5719-llvm-f294e0e513464b97ae1bb2f9532979f8698c441e.zip | |
[spirv] Add support for spv.selection
Similar to spv.loop, spv.selection is another op for modelling
SPIR-V structured control flow. It covers both OpBranchConditional
and OpSwitch with OpSelectionMerge.
Instead of having a `spv.SelectionMerge` op to directly model
selection merge instruction for indicating the merge target,
we use regions to delimit the boundary of the selection: the
merge target is the next op following the `spv.selection` op.
This way it's easier to discover all blocks belonging to
the selection and it plays nicer with the MLIR system.
PiperOrigin-RevId: 272475006
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 90 |
1 files changed, 71 insertions, 19 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 58aebddf29d..445f02c3886 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -250,6 +250,8 @@ private: processBlock(Block *block, llvm::function_ref<void()> actionBeforeTerminator = nullptr); + LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); + LogicalResult processLoopOp(spirv::LoopOp loopOp); LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); @@ -1220,10 +1222,18 @@ Serializer::processBlock(Block *block, namespace { /// A pre-order depth-first vistor for processing basic blocks in a spv.loop op. /// -/// This visitor is special tailored for spv.loop block serialization to satisfy -/// SPIR-V validation rules. It should not be used as a general depth-first -/// block visitor. -class LoopBlockVisitor { +/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order +/// of blocks in a function must satisfy the rule that blocks appear before all +/// blocks they dominate." This can be achieved by a pre-order CFG traversal +/// algorithm. To make the serialization output more logical and readable to +/// human, we perform depth-first CFG traversal and delay the serialization of +/// the merge block (and the continue block) until after all other blocks have +/// been processed. +/// +/// This visitor is special tailored for spv.selection or spv.loop block +/// serialization to satisfy SPIR-V validation rules. It should not be used +/// as a general depth-first block visitor. +class ControlFlowBlockVisitor { public: using BlockHandlerType = llvm::function_ref<LogicalResult(Block *)>; @@ -1232,12 +1242,13 @@ public: /// Skips handling the `headerBlock` and blocks in the `skipBlocks` list. static LogicalResult visit(Block *headerBlock, BlockHandlerType blockHandler, ArrayRef<Block *> skipBlocks) { - return LoopBlockVisitor(blockHandler, skipBlocks) + return ControlFlowBlockVisitor(blockHandler, skipBlocks) .visitHeaderBlock(headerBlock); } private: - LoopBlockVisitor(BlockHandlerType blockHandler, ArrayRef<Block *> skipBlocks) + ControlFlowBlockVisitor(BlockHandlerType blockHandler, + ArrayRef<Block *> skipBlocks) : blockHandler(blockHandler), doneBlocks(skipBlocks.begin(), skipBlocks.end()) {} @@ -1274,16 +1285,54 @@ private: }; } // namespace +LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { + // Assign <id>s to all blocks so that branches inside the SelectionOp can + // resolve properly. + auto &body = selectionOp.body(); + for (Block &block : body) + assignBlockID(&block); + + auto *headerBlock = selectionOp.getHeaderBlock(); + auto *mergeBlock = selectionOp.getMergeBlock(); + auto headerID = findBlockID(headerBlock); + auto mergeID = findBlockID(mergeBlock); + + // This selection is in some MLIR block with preceding and following ops. In + // the binary format, it should reside in separate SPIR-V blocks from its + // preceding and following ops. So we need to emit unconditional branches to + // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow + // afterwards. + encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID}); + + // Emit the selection header block, which dominates all other blocks, first. + // We need to emit an OpSelectionMerge instruction before the loop header + // block's terminator. + auto emitSelectionMerge = [&]() { + // TODO(antiagainst): properly support loop control here + encodeInstructionInto( + functions, spirv::Opcode::OpSelectionMerge, + {mergeID, static_cast<uint32_t>(spirv::LoopControl::None)}); + }; + if (failed(processBlock(headerBlock, emitSelectionMerge))) + return failure(); + + // Process all blocks with a depth-first visitor starting from the header + // block. The selection header block and merge block are skipped by this + // visitor. + auto handleBlock = [&](Block *block) { return processBlock(block); }; + if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock, + {mergeBlock}))) + return failure(); + + // There is nothing to do for the merge block in the selection, which just + // contains a spv._merge op, itself. But we need to have an OpLabel + // instruction to start a new SPIR-V block for ops following this SelectionOp. + // The block should use the <id> for the merge block. + return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID}); +} + LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { - // SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order - // of blocks in a function must satisfy the rule that blocks appear before all - // blocks they dominate." This can be achieved by a pre-order CFG traversal - // algorithm. To make the serialization output more logical and readable to - // human, we perform depth-first CFG traversal and delay the serialization of - // the continue block and the merge block until after all other blocks have - // been processed. - - // Assign <id>s to all blocks so that branchs inside the LoopOp can resolve + // Assign <id>s to all blocks so that branches inside the LoopOp can resolve // properly. We don't need to assign for the entry block, which is just for // satisfying MLIR region's structural requirement. auto &body = loopOp.body(); @@ -1303,7 +1352,6 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // preceding and following ops. So we need to emit unconditional branches to // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow // afterwards. - encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID}); // Emit the loop header block, which dominates all other blocks, first. We @@ -1322,8 +1370,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // block. The loop header block, loop continue block, and loop merge block are // skipped by this visitor and handled later in this function. auto handleBlock = [&](Block *block) { return processBlock(block); }; - if (failed(LoopBlockVisitor::visit(headerBlock, handleBlock, - {continueBlock, mergeBlock}))) + if (failed(ControlFlowBlockVisitor::visit(headerBlock, handleBlock, + {continueBlock, mergeBlock}))) return failure(); // We have handled all other blocks. Now get to the loop continue block. @@ -1332,7 +1380,8 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // There is nothing to do for the merge block in the loop, which just contains // a spv._merge op, itself. But we need to have an OpLabel instruction to - // start a new SPIR-V block for ops following this LoopOp. + // start a new SPIR-V block for ops following this LoopOp. The block should + // use the <id> for the merge block. return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID}); } @@ -1438,6 +1487,9 @@ LogicalResult Serializer::processOperation(Operation *op) { if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { return processGlobalVariableOp(varOp); } + if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) { + return processSelectionOp(selectionOp); + } if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) { return processLoopOp(loopOp); } |

