summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td5
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp12
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp37
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp11
-rw-r--r--mlir/test/Dialect/SPIRV/Serialization/loop.mlir46
5 files changed, 102 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 8de2aebf0b7..1e41fa02638 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -300,7 +300,12 @@ def SPV_LoopOp : SPV_Op<"loop", [InFunctionScope]> {
let regions = (region AnyRegion:$body);
+ let builders = [OpBuilder<"Builder *builder, OperationState &state">];
+
let extraClassDeclaration = [{
+ // Returns the entry block.
+ Block *getEntryBlock();
+
// Returns the loop header block.
Block *getHeaderBlock();
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 3c1563ed515..9d76e56fb5f 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1442,6 +1442,13 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
// spv.loop
//===----------------------------------------------------------------------===//
+void spirv::LoopOp::build(Builder *builder, OperationState &state) {
+ state.addAttribute("loop_control",
+ builder->getI32IntegerAttr(
+ static_cast<uint32_t>(spirv::LoopControl::None)));
+ state.addRegion();
+}
+
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
// TODO(antiagainst): support loop control properly
Builder builder = parser.getBuilder();
@@ -1557,6 +1564,11 @@ static LogicalResult verify(spirv::LoopOp loopOp) {
return success();
}
+Block *spirv::LoopOp::getEntryBlock() {
+ assert(!body().empty() && "op region should not be empty!");
+ return &body().front();
+}
+
Block *spirv::LoopOp::getHeaderBlock() {
assert(!body().empty() && "op region should not be empty!");
// The second block is the loop header block.
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 40b53185529..11509bb7688 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1700,9 +1700,8 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
// merge block so that the newly created LoopOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
- auto control = builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::LoopControl::None));
- auto loopOp = builder.create<spirv::LoopOp>(location, control);
+ // TODO(antiagainst): handle loop control properly
+ auto loopOp = builder.create<spirv::LoopOp>(location);
loopOp.addEntryAndMergeBlock();
return loopOp;
@@ -1810,10 +1809,25 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
headerBlock->replaceAllUsesWith(mergeBlock);
if (isLoop) {
+ // The loop selection/loop header block may have block arguments. Since now
+ // we place the selection/loop op inside the old merge block, we need to
+ // make sure the old merge block has the same block argument list.
+ assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
+ for (BlockArgument *blockArg : headerBlock->getArguments()) {
+ mergeBlock->addArgument(blockArg->getType());
+ }
+
+ // If the loop header block has block arguments, make sure the spv.branch op
+ // matches.
+ SmallVector<Value *, 4> blockArgs;
+ if (!headerBlock->args_empty())
+ blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
+
// The loop entry block should have a unconditional branch jumping to the
// loop header block.
builder.setInsertionPointToEnd(&body.front());
- builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock));
+ builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
+ ArrayRef<Value *>(blockArgs));
}
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
@@ -1901,16 +1915,23 @@ LogicalResult Deserializer::structurizeControlFlow() {
for (const auto &info : blockMergeInfo) {
auto *headerBlock = info.first;
- LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n");
+ LLVM_DEBUG(headerBlock->print(llvm::dbgs()));
const auto &mergeInfo = info.second;
+
auto *mergeBlock = mergeInfo.mergeBlock;
- auto *continueBlock = mergeInfo.continueBlock;
assert(mergeBlock && "merge block cannot be nullptr");
- LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n");
+ if (!mergeBlock->args_empty())
+ return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
+ LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n");
+ LLVM_DEBUG(mergeBlock->print(llvm::dbgs()));
+
+ auto *continueBlock = mergeInfo.continueBlock;
if (continueBlock) {
LLVM_DEBUG(llvm::dbgs()
- << "[cf] continue block " << continueBlock << "\n");
+ << "[cf] continue block " << continueBlock << ":\n");
+ LLVM_DEBUG(continueBlock->print(llvm::dbgs()));
}
if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock,
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 805a3393b0c..02134200a93 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1515,6 +1515,17 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// afterwards.
encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
+ // We omit the LoopOp's entry block and start serialization from the loop
+ // header block. The entry block should not contain any additional ops other
+ // than a single spv.Branch that jumps to the loop header block. However,
+ // the spv.Branch can contain additional block arguments. Those block
+ // arguments must come from out of the loop using implicit capture. We will
+ // need to query the <id> for the value sent and the <id> for the incoming
+ // parent block. For the latter, we need to make sure this block is
+ // registered. The value sent should come from the block this loop resides in.
+ blockIDMap[loopOp.getEntryBlock()] =
+ getBlockID(loopOp.getOperation()->getBlock());
+
// Emit the loop header block, which dominates all other blocks, first. We
// need to emit an OpLoopMerge instruction before the loop header block's
// terminator.
diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
index 88d9a1439b4..e89708fa4f0 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
// Single loop
@@ -61,6 +61,50 @@ spv.module "Logical" "GLSL450" {
// -----
+spv.module "Logical" "GLSL450" {
+ spv.globalVariable @GV1 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ spv.globalVariable @GV2 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ func @loop_kernel() {
+ %0 = spv._address_of @GV1 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ %1 = spv.constant 0 : i32
+ %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ %3 = spv._address_of @GV2 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ %5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ %6 = spv.constant 4 : i32
+ %7 = spv.constant 42 : i32
+ %8 = spv.constant 2 : i32
+// CHECK: spv.Branch ^bb1(%{{.*}} : i32)
+// CHECK-NEXT: ^bb1(%[[OUTARG:.*]]: i32):
+// CHECK-NEXT: spv.loop {
+ spv.loop {
+// CHECK-NEXT: spv.Branch ^bb1(%[[OUTARG]] : i32)
+ spv.Branch ^header(%6 : i32)
+// CHECK-NEXT: ^bb1(%[[HEADARG:.*]]: i32):
+ ^header(%9: i32):
+ %10 = spv.SLessThan %9, %7 : i32
+// CHECK: spv.BranchConditional %{{.*}}, ^bb2, ^bb3
+ spv.BranchConditional %10, ^body, ^merge
+// CHECK-NEXT: ^bb2: // pred: ^bb1
+ ^body:
+ %11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer>
+ %12 = spv.Load "StorageBuffer" %11 : f32
+ %13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer>
+ spv.Store "StorageBuffer" %13, %12 : f32
+// CHECK: %[[ADD:.*]] = spv.IAdd
+ %14 = spv.IAdd %9, %8 : i32
+// CHECK-NEXT: spv.Branch ^bb1(%[[ADD]] : i32)
+ spv.Branch ^header(%14 : i32)
+// CHECK-NEXT: ^bb3:
+ ^merge:
+// CHECK-NEXT: spv._merge
+ spv._merge
+ }
+ spv.Return
+ }
+ spv.EntryPoint "GLCompute" @loop_kernel
+ spv.ExecutionMode @loop_kernel "LocalSize", 1, 1, 1
+} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
+
// TODO(antiagainst): re-enable this after fixing the assertion failure.
// Nested loop
OpenPOWER on IntegriCloud