diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp')
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 22 |
1 files changed, 9 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 0f8e2253980..37f9c2e7b84 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -140,8 +140,8 @@ namespace { /// inside a nested module. It also creates an external function of the same /// name in the parent module. /// -/// The gpu.modules are intended to be compiled to a cubin blob independently in -/// a separate pass. The external functions can then be annotated with the +/// The kernel modules are intended to be compiled to a cubin blob independently +/// in a separate pass. The external functions can then be annotated with the /// symbol of the cubin accessor function. class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> { public: @@ -174,19 +174,15 @@ public: } private: - // Returns a gpu.module containing kernelFunc and all callees (recursive). - gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc, - const SymbolTable &parentSymbolTable) { - // TODO: This code cannot use an OpBuilder because it must be inserted into - // a SymbolTable by the caller. SymbolTable needs to be refactored to - // prevent manual building of Ops with symbols in code using SymbolTables - // and then this needs to use the OpBuilder. + // Returns a module containing kernelFunc and all callees (recursive). + ModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc, + const SymbolTable &parentSymbolTable) { auto context = getModule().getContext(); Builder builder(context); - OperationState state(kernelFunc.getLoc(), - gpu::GPUModuleOp::getOperationName()); - gpu::GPUModuleOp::build(&builder, state, kernelFunc.getName()); - auto kernelModule = cast<gpu::GPUModuleOp>(Operation::create(state)); + auto kernelModule = + ModuleOp::create(builder.getUnknownLoc(), kernelFunc.getName()); + kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(), + builder.getUnitAttr()); SymbolTable symbolTable(kernelModule); symbolTable.insert(kernelFunc); |

