diff options
| author | Christian Sigg <csigg@google.com> | 2019-09-24 06:29:25 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-09-24 06:29:54 -0700 |
| commit | 74cdbf5909e57b42b6ed5b3b6eea4f76448a7d48 (patch) | |
| tree | 1afd7a89c7e53b2f622fc80c646a3c4a19c5ee89 /mlir/lib/Dialect/GPU/Transforms | |
| parent | eba6014cdc1cc1a9d9732a2e9010afde2d9d898e (diff) | |
| download | bcm5719-llvm-74cdbf5909e57b42b6ed5b3b6eea4f76448a7d48.tar.gz bcm5719-llvm-74cdbf5909e57b42b6ed5b3b6eea4f76448a7d48.zip | |
Clone called functions into nested GPU module.
PiperOrigin-RevId: 270891190
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms')
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 9bf4cf6e643..f38a2e81986 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -144,13 +144,10 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> { public: void runOnModule() override { ModuleManager moduleManager(getModule()); - auto context = getModule().getContext(); - Builder builder(context); for (auto func : getModule().getOps<FuncOp>()) { // Insert just after the function. Block::iterator insertPt(func.getOperation()->getNextNode()); func.walk([&](gpu::LaunchOp op) { - // TODO(b/141098412): Handle called functions and globals. FuncOp outlinedFunc = outlineKernelFunc(op); // Potentially renames outlinedFunc to make symbol unique. @@ -164,14 +161,41 @@ public: kernelFunc.getBody().takeBody(outlinedFunc.getBody()); // Create nested module and insert kernelFunc. - auto kernelModule = ModuleOp::create(UnknownLoc::get(context)); - kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(), - builder.getUnitAttr()); - kernelModule.push_back(kernelFunc); + auto kernelModule = createKernelModule(kernelFunc, moduleManager); getModule().insert(insertPt, kernelModule); }); } } + +private: + // Returns a module containing kernelFunc and all callees (recursive). + ModuleOp createKernelModule(FuncOp kernelFunc, + const ModuleManager &parentModuleManager) { + auto context = getModule().getContext(); + auto kernelModule = ModuleOp::create(UnknownLoc::get(context)); + kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(), + UnitAttr::get(context)); + ModuleManager moduleManager(kernelModule); + + llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc}; + while (!funcsToInsert.empty()) { + FuncOp func = funcsToInsert.pop_back_val(); + moduleManager.insert(func); + + // TODO(b/141098412): Support any op with a callable interface. + func.walk([&](CallOp call) { + auto callee = call.callee(); + if (moduleManager.lookupSymbol<FuncOp>(callee)) + return; + + auto calleeFromParent = + parentModuleManager.lookupSymbol<FuncOp>(callee); + funcsToInsert.push_back(calleeFromParent.clone()); + }); + } + + return kernelModule; + } }; } // namespace |

