summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHanhan Wang <hanchung@google.com>2019-11-12 18:58:36 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-12 18:59:15 -0800
commit85d7fb3324a6442e865c87ea766992ab096f8859 (patch)
tree6881e629ceaa3acf9c97a62d3454489dad302973
parent2be53603e9296e86ae6ef529c37053e198560f60 (diff)
downloadbcm5719-llvm-85d7fb3324a6442e865c87ea766992ab096f8859.tar.gz
bcm5719-llvm-85d7fb3324a6442e865c87ea766992ab096f8859.zip
Make VariableOp instructions be in the first block in the function.
Since VariableOp is serialized during processBlock, we add two more fields, `functionHeader` and `functionBody`, to collect instructions for a function. After all the blocks have been processed, we append them to the `functions`. Also, fix a bug in processGlobalVariableOp. The global variables should be encoded into `typesGlobalValues`. PiperOrigin-RevId: 280105366
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp117
-rw-r--r--mlir/test/Dialect/SPIRV/Serialization/constant.mlir13
-rw-r--r--mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp2
3 files changed, 111 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 02134200a93..0ff79d92ee1 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -177,6 +177,8 @@ private:
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(FuncOp op);
+ LogicalResult processVariableOp(spirv::VariableOp op);
+
/// Process a SPIR-V GlobalVariableOp
LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
@@ -374,6 +376,19 @@ private:
SmallVector<uint32_t, 0> typesGlobalValues;
SmallVector<uint32_t, 0> functions;
+ /// `functionHeader` contains all the instructions that must be in the first
+ /// block in the function, and `functionBody` contains the rest. After
+ /// processing FuncOp, the encoded instructions of a function are appended to
+ /// `functions`. An example of instructions in `functionHeader` in order:
+ /// OpFunction ...
+ /// OpFunctionParameter ...
+ /// OpFunctionParameter ...
+ /// OpLabel ...
+ /// OpVariable ...
+ /// OpVariable ...
+ SmallVector<uint32_t, 0> functionHeader;
+ SmallVector<uint32_t, 0> functionBody;
+
/// Map from type used in SPIR-V module to their <id>s.
DenseMap<Type, uint32_t> typeIDMap;
@@ -671,6 +686,7 @@ Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
LogicalResult Serializer::processFuncOp(FuncOp op) {
LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
+ assert(functionHeader.empty() && functionBody.empty());
uint32_t fnTypeID = 0;
// Generate type of the function.
@@ -694,7 +710,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
// TODO : Support other function control options.
operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
operands.push_back(fnTypeID);
- encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
// Add function name.
if (failed(processName(funcID, op.getName()))) {
@@ -709,7 +725,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
}
auto argValueID = getNextID();
valueIDMap[arg] = argValueID;
- encodeInstructionInto(functions, spirv::Opcode::OpFunctionParameter,
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
{argTypeID, argValueID});
}
@@ -718,9 +734,18 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return op.emitError("external function is unhandled");
}
+ // Some instructions (e.g., OpVariable) in a function must be in the first
+ // block in the function. These instructions will be put in functionHeader.
+ // Thus, we put the label in functionHeader first, and omit it from the first
+ // block.
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
+ {getOrCreateBlockID(&op.front())});
+ processBlock(&op.front(), /*omitLabel=*/true);
if (failed(visitInPrettyBlockOrder(
- &op.front(), [&](Block *block) { return processBlock(block); })))
+ &op.front(), [&](Block *block) { return processBlock(block); },
+ /*skipHeader=*/true))) {
return failure();
+ }
// There might be OpPhi instructions who have value references needing to fix.
for (auto deferredValue : deferredPhiValues) {
@@ -730,14 +755,63 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
<< " to id = " << id << '\n');
assert(id && "OpPhi references undefined value!");
for (size_t offset : deferredValue.second)
- functions[offset] = id;
+ functionBody[offset] = id;
}
deferredPhiValues.clear();
LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
<< "' --\n");
// Insert OpFunctionEnd.
- return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
+ if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
+ {}))) {
+ return failure();
+ }
+
+ functions.append(functionHeader.begin(), functionHeader.end());
+ functions.append(functionBody.begin(), functionBody.end());
+ functionHeader.clear();
+ functionBody.clear();
+
+ return success();
+}
+
+LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
+ SmallVector<uint32_t, 4> operands;
+ SmallVector<StringRef, 2> elidedAttrs;
+ uint32_t resultID = 0;
+ uint32_t resultTypeID = 0;
+ if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
+ return failure();
+ }
+ operands.push_back(resultTypeID);
+ resultID = getNextID();
+ valueIDMap[op.getResult()] = resultID;
+ operands.push_back(resultID);
+ auto attr = op.getAttr(spirv::attributeName<spirv::StorageClass>());
+ if (attr) {
+ operands.push_back(static_cast<uint32_t>(
+ attr.cast<IntegerAttr>().getValue().getZExtValue()));
+ }
+ elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+ for (auto arg : op.getODSOperands(0)) {
+ auto argID = getValueID(arg);
+ if (!argID) {
+ return emitError(op.getLoc(), "operand 0 has a use before def");
+ }
+ operands.push_back(argID);
+ }
+ encodeInstructionInto(functionHeader, spirv::getOpcode<spirv::VariableOp>(),
+ operands);
+ for (auto attr : op.getAttrs()) {
+ if (llvm::any_of(elidedAttrs,
+ [&](StringRef elided) { return attr.first.is(elided); })) {
+ continue;
+ }
+ if (failed(processDecoration(op.getLoc(), resultID, attr))) {
+ return failure();
+ }
+ }
+ return success();
}
LogicalResult
@@ -789,7 +863,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
elidedAttrs.push_back("initializer");
}
- if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable,
+ if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
operands))) {
elidedAttrs.push_back("initializer");
return failure();
@@ -1360,7 +1434,7 @@ Serializer::processBlock(Block *block, bool omitLabel,
<< "[block] " << block << " (id = " << blockID << ")\n");
// Emit OpLabel for this block.
- encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID});
+ encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
}
// Emit OpPhi instructions for block arguments, if any.
@@ -1431,7 +1505,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// The op generating this value hasn't been visited yet so we don't have
// an <id> assigned yet. Record this to fix up later.
LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
- deferredPhiValues[value].push_back(functions.size() + 1 +
+ deferredPhiValues[value].push_back(functionBody.size() + 1 +
phiArgs.size());
} else {
LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
@@ -1441,7 +1515,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
phiArgs.push_back(predBlockId);
}
- encodeInstructionInto(functions, spirv::Opcode::OpPhi, phiArgs);
+ encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
valueIDMap[arg] = phiID;
}
@@ -1465,7 +1539,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
auto emitSelectionMerge = [&]() {
// TODO(antiagainst): properly support loop control here
encodeInstructionInto(
- functions, spirv::Opcode::OpSelectionMerge,
+ functionBody, spirv::Opcode::OpSelectionMerge,
{mergeID, static_cast<uint32_t>(spirv::LoopControl::None)});
};
// For structured selection, we cannot have blocks in the selection construct
@@ -1489,7 +1563,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// contains a spv._merge op, itself. But we need to have an OpLabel
// instruction to start a new SPIR-V block for ops following this SelectionOp.
// The block should use the <id> for the merge block.
- return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
}
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
@@ -1513,7 +1587,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// preceding and following ops. So we need to emit unconditional branches to
// jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
// afterwards.
- encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
+ encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
// We omit the LoopOp's entry block and start serialization from the loop
// header block. The entry block should not contain any additional ops other
@@ -1532,7 +1606,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
auto emitLoopMerge = [&]() {
// TODO(antiagainst): properly support loop control here
encodeInstructionInto(
- functions, spirv::Opcode::OpLoopMerge,
+ functionBody, spirv::Opcode::OpLoopMerge,
{mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)});
};
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
@@ -1554,7 +1628,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// a spv._merge op, itself. But we need to have an OpLabel instruction to
// start a new SPIR-V block for ops following this LoopOp. The block should
// use the <id> for the merge block.
- return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
}
LogicalResult Serializer::processBranchConditionalOp(
@@ -1569,12 +1643,12 @@ LogicalResult Serializer::processBranchConditionalOp(
arguments.push_back(val.cast<IntegerAttr>().getInt());
}
- return encodeInstructionInto(functions, spirv::Opcode::OpBranchConditional,
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
arguments);
}
LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
- return encodeInstructionInto(functions, spirv::Opcode::OpBranch,
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
{getOrCreateBlockID(branchOp.getTarget())});
}
@@ -1610,7 +1684,7 @@ LogicalResult Serializer::encodeExtensionInstruction(
extInstOperands.push_back(setID);
extInstOperands.push_back(extensionOpcode);
extInstOperands.append(std::next(operands.begin(), 2), operands.end());
- return encodeInstructionInto(functions, spirv::Opcode::OpExtInst,
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
extInstOperands);
}
@@ -1658,6 +1732,9 @@ LogicalResult Serializer::processOperation(Operation *op) {
if (auto fnOp = dyn_cast<FuncOp>(op)) {
return processFuncOp(fnOp);
}
+ if (auto varOp = dyn_cast<spirv::VariableOp>(op)) {
+ return processVariableOp(varOp);
+ }
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
return processGlobalVariableOp(varOp);
}
@@ -1736,7 +1813,7 @@ Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
operands.push_back(operand);
}
- return encodeInstructionInto(functions, spirv::Opcode::OpControlBarrier,
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
operands);
}
@@ -1783,7 +1860,7 @@ Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
operands.push_back(operand);
}
- return encodeInstructionInto(functions, spirv::Opcode::OpMemoryBarrier,
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier,
operands);
}
@@ -1814,7 +1891,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
valueIDMap[op.getResult(0)] = funcCallID;
}
- return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall,
+ return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
operands);
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir
index acfa40a06cd..953120946db 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir
@@ -165,4 +165,17 @@ spv.module "Logical" "GLSL450" {
%1 = spv.IAdd %0, %0 : i32
spv.ReturnValue %1 : i32
}
+
+ // CHECK-LABEL: @const_variable
+ func @const_variable(%arg0 : i32, %arg1 : i32) -> () {
+ // CHECK: %[[CONST:.*]] = spv.constant 5 : i32
+ // CHECK: spv.Variable init(%[[CONST]]) : !spv.ptr<i32, Function>
+ // CHECK: spv.IAdd %arg0, %arg1
+ %0 = spv.IAdd %arg0, %arg1 : i32
+ %1 = spv.constant 5 : i32
+ %2 = spv.Variable init(%1) : !spv.ptr<i32, Function>
+ %3 = spv.Load "Function" %2 : i32
+ %4 = spv.IAdd %0, %3 : i32
+ spv.Return
+ }
}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 34bfd403975..f39295a22c8 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -231,7 +231,7 @@ static void emitSerializationFunction(const Record *attrClass,
record->getValueAsInt("extendedInstOpcode"), operands);
} else {
os << formatv(" encodeInstructionInto("
- "functions, spirv::getOpcode<{0}>(), {1});\n",
+ "functionBody, spirv::getOpcode<{0}>(), {1});\n",
op.getQualCppClassName(), operands);
}
OpenPOWER on IntegriCloud