diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 5 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 12 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | 37 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 11 | ||||
| -rw-r--r-- | mlir/test/Dialect/SPIRV/Serialization/loop.mlir | 46 |
5 files changed, 102 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 8de2aebf0b7..1e41fa02638 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -300,7 +300,12 @@ def SPV_LoopOp : SPV_Op<"loop", [InFunctionScope]> { let regions = (region AnyRegion:$body); + let builders = [OpBuilder<"Builder *builder, OperationState &state">]; + let extraClassDeclaration = [{ + // Returns the entry block. + Block *getEntryBlock(); + // Returns the loop header block. Block *getHeaderBlock(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 3c1563ed515..9d76e56fb5f 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1442,6 +1442,13 @@ static LogicalResult verify(spirv::LoadOp loadOp) { // spv.loop //===----------------------------------------------------------------------===// +void spirv::LoopOp::build(Builder *builder, OperationState &state) { + state.addAttribute("loop_control", + builder->getI32IntegerAttr( + static_cast<uint32_t>(spirv::LoopControl::None))); + state.addRegion(); +} + static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) { // TODO(antiagainst): support loop control properly Builder builder = parser.getBuilder(); @@ -1557,6 +1564,11 @@ static LogicalResult verify(spirv::LoopOp loopOp) { return success(); } +Block *spirv::LoopOp::getEntryBlock() { + assert(!body().empty() && "op region should not be empty!"); + return &body().front(); +} + Block *spirv::LoopOp::getHeaderBlock() { assert(!body().empty() && "op region should not be empty!"); // The second block is the loop header block. diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 40b53185529..11509bb7688 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1700,9 +1700,8 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp() { // merge block so that the newly created LoopOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr( - static_cast<uint32_t>(spirv::LoopControl::None)); - auto loopOp = builder.create<spirv::LoopOp>(location, control); + // TODO(antiagainst): handle loop control properly + auto loopOp = builder.create<spirv::LoopOp>(location); loopOp.addEntryAndMergeBlock(); return loopOp; @@ -1810,10 +1809,25 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { headerBlock->replaceAllUsesWith(mergeBlock); if (isLoop) { + // The loop selection/loop header block may have block arguments. Since now + // we place the selection/loop op inside the old merge block, we need to + // make sure the old merge block has the same block argument list. + assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); + for (BlockArgument *blockArg : headerBlock->getArguments()) { + mergeBlock->addArgument(blockArg->getType()); + } + + // If the loop header block has block arguments, make sure the spv.branch op + // matches. + SmallVector<Value *, 4> blockArgs; + if (!headerBlock->args_empty()) + blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; + // The loop entry block should have a unconditional branch jumping to the // loop header block. builder.setInsertionPointToEnd(&body.front()); - builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock)); + builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), + ArrayRef<Value *>(blockArgs)); } // All the blocks cloned into the SelectionOp/LoopOp's region can now be @@ -1901,16 +1915,23 @@ LogicalResult Deserializer::structurizeControlFlow() { for (const auto &info : blockMergeInfo) { auto *headerBlock = info.first; - LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); + LLVM_DEBUG(headerBlock->print(llvm::dbgs())); const auto &mergeInfo = info.second; + auto *mergeBlock = mergeInfo.mergeBlock; - auto *continueBlock = mergeInfo.continueBlock; assert(mergeBlock && "merge block cannot be nullptr"); - LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n"); + if (!mergeBlock->args_empty()) + return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); + LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); + LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); + + auto *continueBlock = mergeInfo.continueBlock; if (continueBlock) { LLVM_DEBUG(llvm::dbgs() - << "[cf] continue block " << continueBlock << "\n"); + << "[cf] continue block " << continueBlock << ":\n"); + LLVM_DEBUG(continueBlock->print(llvm::dbgs())); } if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 805a3393b0c..02134200a93 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1515,6 +1515,17 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { // afterwards. encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID}); + // We omit the LoopOp's entry block and start serialization from the loop + // header block. The entry block should not contain any additional ops other + // than a single spv.Branch that jumps to the loop header block. However, + // the spv.Branch can contain additional block arguments. Those block + // arguments must come from out of the loop using implicit capture. We will + // need to query the <id> for the value sent and the <id> for the incoming + // parent block. For the latter, we need to make sure this block is + // registered. The value sent should come from the block this loop resides in. + blockIDMap[loopOp.getEntryBlock()] = + getBlockID(loopOp.getOperation()->getBlock()); + // Emit the loop header block, which dominates all other blocks, first. We // need to emit an OpLoopMerge instruction before the loop header block's // terminator. diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir index 88d9a1439b4..e89708fa4f0 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s // Single loop @@ -61,6 +61,50 @@ spv.module "Logical" "GLSL450" { // ----- +spv.module "Logical" "GLSL450" { + spv.globalVariable @GV1 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer> + spv.globalVariable @GV2 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer> + func @loop_kernel() { + %0 = spv._address_of @GV1 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer> + %1 = spv.constant 0 : i32 + %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer> + %3 = spv._address_of @GV2 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer> + %6 = spv.constant 4 : i32 + %7 = spv.constant 42 : i32 + %8 = spv.constant 2 : i32 +// CHECK: spv.Branch ^bb1(%{{.*}} : i32) +// CHECK-NEXT: ^bb1(%[[OUTARG:.*]]: i32): +// CHECK-NEXT: spv.loop { + spv.loop { +// CHECK-NEXT: spv.Branch ^bb1(%[[OUTARG]] : i32) + spv.Branch ^header(%6 : i32) +// CHECK-NEXT: ^bb1(%[[HEADARG:.*]]: i32): + ^header(%9: i32): + %10 = spv.SLessThan %9, %7 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^bb2, ^bb3 + spv.BranchConditional %10, ^body, ^merge +// CHECK-NEXT: ^bb2: // pred: ^bb1 + ^body: + %11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer> + %12 = spv.Load "StorageBuffer" %11 : f32 + %13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer> + spv.Store "StorageBuffer" %13, %12 : f32 +// CHECK: %[[ADD:.*]] = spv.IAdd + %14 = spv.IAdd %9, %8 : i32 +// CHECK-NEXT: spv.Branch ^bb1(%[[ADD]] : i32) + spv.Branch ^header(%14 : i32) +// CHECK-NEXT: ^bb3: + ^merge: +// CHECK-NEXT: spv._merge + spv._merge + } + spv.Return + } + spv.EntryPoint "GLCompute" @loop_kernel + spv.ExecutionMode @loop_kernel "LocalSize", 1, 1, 1 +} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]} + // TODO(antiagainst): re-enable this after fixing the assertion failure. // Nested loop |

