summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp')
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp22
1 files changed, 9 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 0f8e2253980..37f9c2e7b84 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -140,8 +140,8 @@ namespace {
/// inside a nested module. It also creates an external function of the same
/// name in the parent module.
///
-/// The gpu.modules are intended to be compiled to a cubin blob independently in
-/// a separate pass. The external functions can then be annotated with the
+/// The kernel modules are intended to be compiled to a cubin blob independently
+/// in a separate pass. The external functions can then be annotated with the
/// symbol of the cubin accessor function.
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
@@ -174,19 +174,15 @@ public:
}
private:
- // Returns a gpu.module containing kernelFunc and all callees (recursive).
- gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
- const SymbolTable &parentSymbolTable) {
- // TODO: This code cannot use an OpBuilder because it must be inserted into
- // a SymbolTable by the caller. SymbolTable needs to be refactored to
- // prevent manual building of Ops with symbols in code using SymbolTables
- // and then this needs to use the OpBuilder.
+ // Returns a module containing kernelFunc and all callees (recursive).
+ ModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
+ const SymbolTable &parentSymbolTable) {
auto context = getModule().getContext();
Builder builder(context);
- OperationState state(kernelFunc.getLoc(),
- gpu::GPUModuleOp::getOperationName());
- gpu::GPUModuleOp::build(&builder, state, kernelFunc.getName());
- auto kernelModule = cast<gpu::GPUModuleOp>(Operation::create(state));
+ auto kernelModule =
+ ModuleOp::create(builder.getUnknownLoc(), kernelFunc.getName());
+ kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
+ builder.getUnitAttr());
SymbolTable symbolTable(kernelModule);
symbolTable.insert(kernelFunc);
OpenPOWER on IntegriCloud