summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp')
-rw-r--r-- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 21
1 files changed, 10 insertions, 11 deletions
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e2b1e0e533c..84bc7ff1d5f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -200,7 +200,7 @@ private:
auto type = operand.getType().cast<LLVM::LLVMType>();
// Create shared memory array to store the warp reduction.
- auto module = operand.getDefiningOp()->getParentOfType<ModuleOp>();
+ auto module = operand.getDefiningOp()->getParentOfType<gpu::GPUModuleOp>();
assert(module && "op must belong to a module");
Value sharedMemPtr =
createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
@@ -391,10 +391,10 @@ private:
}
/// Creates a global array stored in shared memory.
- Value createSharedMemoryArray(Location loc, ModuleOp module,
+ Value createSharedMemoryArray(Location loc, gpu::GPUModuleOp module,
LLVM::LLVMType elementType, int numElements,
ConversionPatternRewriter &rewriter) const {
- OpBuilder builder(module.getBodyRegion());
+ OpBuilder builder(module.body());
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
StringRef name = "reduce_buffer";
@@ -699,13 +699,11 @@ struct GPUReturnOpLowering : public LLVMOpLowering {
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
-class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
+class LowerGpuOpsToNVVMOpsPass
+ : public OperationPass<LowerGpuOpsToNVVMOpsPass, gpu::GPUModuleOp> {
public:
- void runOnModule() override {
- ModuleOp m = getModule();
- if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName()))
- return;
-
+ void runOnOperation() override {
+ gpu::GPUModuleOp m = getOperation();
OwningRewritePatternList patterns;
NVVMTypeConverter converter(m.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
@@ -718,7 +716,7 @@ public:
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<NVVM::NVVMDialect>();
// TODO(csigg): Remove once we support replacing non-root ops.
- target.addLegalOp<gpu::YieldOp>();
+ target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
signalPassFailure();
}
@@ -750,7 +748,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
"__nv_exp");
}
-std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() {
+std::unique_ptr<OpPassBase<gpu::GPUModuleOp>>
+mlir::createLowerGpuOpsToNVVMOpsPass() {
return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
}
OpenPOWER on IntegriCloud