diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 55 |
1 files changed, 15 insertions, 40 deletions
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 2fd8cedfd63..a90cea99be4 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -63,27 +63,13 @@ private: SmallVector<int32_t, 3> workGroupSizeAsInt32; }; -/// Pattern to convert a module with gpu.kernel_module attribute to a -/// spv.module. -class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> { +/// Pattern to convert a gpu.module to a spv.module. +class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> { public: - using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering; + using SPIRVOpLowering<gpu::GPUModuleOp>::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Pattern to convert a module terminator op to a terminator of spv.module op. -// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined -// in ODS. -class KernelModuleTerminatorConversion final - : public SPIRVOpLowering<ModuleTerminatorOp> { -public: - using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering; - - PatternMatchResult - matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands, + matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override; }; @@ -284,16 +270,12 @@ KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp, } //===----------------------------------------------------------------------===// -// ModuleOp with gpu.kernel_module. +// ModuleOp with gpu.module. //===----------------------------------------------------------------------===// -PatternMatchResult KernelModuleConversion::matchAndRewrite( - ModuleOp moduleOp, ArrayRef<Value> operands, +PatternMatchResult GPUModuleConversion::matchAndRewrite( + gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { - if (!moduleOp.getAttrOfType<UnitAttr>( - gpu::GPUDialect::getKernelModuleAttrName())) { - return matchFailure(); - } // TODO : Generalize this to account for different extensions, // capabilities, extended_instruction_sets, other addressing models // and memory models. @@ -302,8 +284,8 @@ PatternMatchResult KernelModuleConversion::matchAndRewrite( spirv::MemoryModel::GLSL450, spirv::Capability::Shader, spirv::Extension::SPV_KHR_storage_buffer_storage_class); // Move the region from the module op into the SPIR-V module. - Region &spvModuleRegion = spvModule.getOperation()->getRegion(0); - rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion, + Region &spvModuleRegion = spvModule.body(); + rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion, spvModuleRegion.begin()); // The spv.module build method adds a block with a terminator. Remove that // block. The terminator of the module op in the remaining block will be @@ -314,17 +296,6 @@ PatternMatchResult KernelModuleConversion::matchAndRewrite( } //===----------------------------------------------------------------------===// -// ModuleTerminatorOp for gpu.kernel_module. -//===----------------------------------------------------------------------===// - -PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite( - ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp); - return matchSuccess(); -} - -//===----------------------------------------------------------------------===// // GPU return inside kernel functions to SPIR-V return. //===----------------------------------------------------------------------===// @@ -342,14 +313,18 @@ PatternMatchResult GPUReturnOpConversion::matchAndRewrite( // GPU To SPIRV Patterns. //===----------------------------------------------------------------------===// +namespace { +#include "GPUToSPIRV.cpp.inc" +} + void mlir::populateGPUToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns, ArrayRef<int64_t> workGroupSize) { + populateWithGenerated(context, &patterns); patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize); patterns.insert< - GPUReturnOpConversion, ForOpConversion, KernelModuleConversion, - KernelModuleTerminatorConversion, + GPUReturnOpConversion, ForOpConversion, GPUModuleConversion, LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>, LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>, LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>, |