diff options
author | Mahesh Ravishankar <ravishankarm@google.com> | 2019-12-09 09:51:25 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-09 09:52:21 -0800 |
commit | 4a62019eb88f0f8fafe8f4f7ab1c984313b0b022 (patch) | |
tree | 432f1cef5f32fca9b9654bef352508d69332f63f /mlir/lib/Dialect | |
parent | 312ccb1c0f6df2fb67a7ad24ab4ce70dadbcda37 (diff) | |
download | bcm5719-llvm-4a62019eb88f0f8fafe8f4f7ab1c984313b0b022.tar.gz bcm5719-llvm-4a62019eb88f0f8fafe8f4f7ab1c984313b0b022.zip |
Add lowering for module with gpu.kernel_module attribute.
The existing GPU to SPIR-V lowering created a spv.module for every
function with gpu.kernel attribute. A better approach is to lower the
module that the function lives in (which has the attribute
gpu.kernel_module) to a spv.module operation. This better captures the
host-device separation modeled by GPU dialect and simplifies the
lowering as well.
PiperOrigin-RevId: 284574688
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 49 |
2 files changed, 42 insertions, 9 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 694a98fd075..bf17d10d808 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction( newFuncOp.setType(rewriter.getFunctionType( signatureConverter.getConvertedTypes(), llvm::None)); rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); - rewriter.replaceOp(funcOp.getOperation(), llvm::None); + rewriter.eraseOp(funcOp); // Set the attributes for argument and the function. StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 66af4305858..7061200fa65 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op, return success(); } +template <typename Ty> +static ArrayAttr +getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, + llvm::function_ref<StringRef(Ty)> stringifyFn) { + if (enumValues.empty()) { + return nullptr; + } + SmallVector<StringRef, 1> enumValStrs; + enumValStrs.reserve(enumValues.size()); + for (auto val : enumValues) { + enumValStrs.emplace_back(stringifyFn(val)); + } + return builder.getStrArrayAttr(enumValStrs); +} + template <typename EnumClass> static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser &parser, @@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) { ensureTerminator(*state.addRegion(), *builder, state.location); } +// TODO(ravishankarm): This is only here for resolving some dependency outside +// of mlir. Remove once it is done. void spirv::ModuleOp::build(Builder *builder, OperationState &state, IntegerAttr addressing_model, - IntegerAttr memory_model, ArrayAttr capabilities, - ArrayAttr extensions, - ArrayAttr extended_instruction_sets) { + IntegerAttr memory_model) { state.addAttribute("addressing_model", addressing_model); state.addAttribute("memory_model", memory_model); - if (capabilities) - state.addAttribute("capabilities", capabilities); - if (extensions) - state.addAttribute("extensions", extensions); + build(builder, state); +} + +void spirv::ModuleOp::build(Builder *builder, OperationState &state, + spirv::AddressingModel addressing_model, + spirv::MemoryModel memory_model, + ArrayRef<spirv::Capability> capabilities, + ArrayRef<spirv::Extension> extensions, + ArrayAttr extended_instruction_sets) { + state.addAttribute( + "addressing_model", + builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model))); + state.addAttribute("memory_model", builder->getI32IntegerAttr( + static_cast<int32_t>(memory_model))); + if (!capabilities.empty()) + state.addAttribute("capabilities", + getStrArrayAttrForEnumList<spirv::Capability>( + *builder, capabilities, spirv::stringifyCapability)); + if (!extensions.empty()) + state.addAttribute("extensions", + getStrArrayAttrForEnumList<spirv::Extension>( + *builder, extensions, spirv::stringifyExtension)); if (extended_instruction_sets) state.addAttribute("extended_instruction_sets", extended_instruction_sets); - ensureTerminator(*state.addRegion(), *builder, state.location); + build(builder, state); } static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { |