summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
authorMahesh Ravishankar <ravishankarm@google.com>2019-12-09 09:51:25 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-09 09:52:21 -0800
commit4a62019eb88f0f8fafe8f4f7ab1c984313b0b022 (patch)
tree432f1cef5f32fca9b9654bef352508d69332f63f /mlir/lib/Dialect
parent312ccb1c0f6df2fb67a7ad24ab4ce70dadbcda37 (diff)
downloadbcm5719-llvm-4a62019eb88f0f8fafe8f4f7ab1c984313b0b022.tar.gz
bcm5719-llvm-4a62019eb88f0f8fafe8f4f7ab1c984313b0b022.zip
Add lowering for module with gpu.kernel_module attribute.
The existing GPU to SPIR-V lowering created a spv.module for every function with gpu.kernel attribute. A better approach is to lower the module that the function lives in (which has the attribute gpu.kernel_module) to a spv.module operation. This better captures the host-device separation modeled by GPU dialect and simplifies the lowering as well. PiperOrigin-RevId: 284574688
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp49
2 files changed, 42 insertions, 9 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 694a98fd075..bf17d10d808 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
newFuncOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+ rewriter.eraseOp(funcOp);
// Set the attributes for argument and the function.
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 66af4305858..7061200fa65 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op,
return success();
}
+template <typename Ty>
+static ArrayAttr
+getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
+ llvm::function_ref<StringRef(Ty)> stringifyFn) {
+ if (enumValues.empty()) {
+ return nullptr;
+ }
+ SmallVector<StringRef, 1> enumValStrs;
+ enumValStrs.reserve(enumValues.size());
+ for (auto val : enumValues) {
+ enumValStrs.emplace_back(stringifyFn(val));
+ }
+ return builder.getStrArrayAttr(enumValStrs);
+}
+
template <typename EnumClass>
static ParseResult
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
@@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
ensureTerminator(*state.addRegion(), *builder, state.location);
}
+// TODO(ravishankarm): This is only here for resolving some dependency outside
+// of mlir. Remove once it is done.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
IntegerAttr addressing_model,
- IntegerAttr memory_model, ArrayAttr capabilities,
- ArrayAttr extensions,
- ArrayAttr extended_instruction_sets) {
+ IntegerAttr memory_model) {
state.addAttribute("addressing_model", addressing_model);
state.addAttribute("memory_model", memory_model);
- if (capabilities)
- state.addAttribute("capabilities", capabilities);
- if (extensions)
- state.addAttribute("extensions", extensions);
+ build(builder, state);
+}
+
+void spirv::ModuleOp::build(Builder *builder, OperationState &state,
+ spirv::AddressingModel addressing_model,
+ spirv::MemoryModel memory_model,
+ ArrayRef<spirv::Capability> capabilities,
+ ArrayRef<spirv::Extension> extensions,
+ ArrayAttr extended_instruction_sets) {
+ state.addAttribute(
+ "addressing_model",
+ builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+ state.addAttribute("memory_model", builder->getI32IntegerAttr(
+ static_cast<int32_t>(memory_model)));
+ if (!capabilities.empty())
+ state.addAttribute("capabilities",
+ getStrArrayAttrForEnumList<spirv::Capability>(
+ *builder, capabilities, spirv::stringifyCapability));
+ if (!extensions.empty())
+ state.addAttribute("extensions",
+ getStrArrayAttrForEnumList<spirv::Extension>(
+ *builder, extensions, spirv::stringifyExtension));
if (extended_instruction_sets)
state.addAttribute("extended_instruction_sets", extended_instruction_sets);
- ensureTerminator(*state.addRegion(), *builder, state.location);
+ build(builder, state);
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
OpenPOWER on IntegriCloud