diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 21 |
1 file changed, 10 insertions, 11 deletions
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index e2b1e0e533c..84bc7ff1d5f 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -200,7 +200,7 @@ private: auto type = operand.getType().cast<LLVM::LLVMType>(); // Create shared memory array to store the warp reduction. - auto module = operand.getDefiningOp()->getParentOfType<ModuleOp>(); + auto module = operand.getDefiningOp()->getParentOfType<gpu::GPUModuleOp>(); assert(module && "op must belong to a module"); Value sharedMemPtr = createSharedMemoryArray(loc, module, type, kWarpSize, rewriter); @@ -391,10 +391,10 @@ private: } /// Creates a global array stored in shared memory. - Value createSharedMemoryArray(Location loc, ModuleOp module, + Value createSharedMemoryArray(Location loc, gpu::GPUModuleOp module, LLVM::LLVMType elementType, int numElements, ConversionPatternRewriter &rewriter) const { - OpBuilder builder(module.getBodyRegion()); + OpBuilder builder(module.body()); auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); StringRef name = "reduce_buffer"; @@ -699,13 +699,11 @@ struct GPUReturnOpLowering : public LLVMOpLowering { /// /// This pass only handles device code and is not meant to be run on GPU host /// code. 
-class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> { +class LowerGpuOpsToNVVMOpsPass + : public OperationPass<LowerGpuOpsToNVVMOpsPass, gpu::GPUModuleOp> { public: - void runOnModule() override { - ModuleOp m = getModule(); - if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName())) - return; - + void runOnOperation() override { + gpu::GPUModuleOp m = getOperation(); OwningRewritePatternList patterns; NVVMTypeConverter converter(m.getContext()); populateStdToLLVMConversionPatterns(converter, patterns); @@ -718,7 +716,7 @@ public: target.addLegalDialect<LLVM::LLVMDialect>(); target.addLegalDialect<NVVM::NVVMDialect>(); // TODO(csigg): Remove once we support replacing non-root ops. - target.addLegalOp<gpu::YieldOp>(); + target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>(); if (failed(applyPartialConversion(m, target, patterns, &converter))) signalPassFailure(); } @@ -750,7 +748,8 @@ void mlir::populateGpuToNVVMConversionPatterns( "__nv_exp"); } -std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() { +std::unique_ptr<OpPassBase<gpu::GPUModuleOp>> +mlir::createLowerGpuOpsToNVVMOpsPass() { return std::make_unique<LowerGpuOpsToNVVMOpsPass>(); } |