summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib/Dialect/GPU/Transforms
diff options
context:
space:
mode:
author: Christian Sigg <csigg@google.com> 2019-09-24 06:29:25 -0700
committer: A. Unique TensorFlower <gardener@tensorflow.org> 2019-09-24 06:29:54 -0700
commit: 74cdbf5909e57b42b6ed5b3b6eea4f76448a7d48 (patch)
tree: 1afd7a89c7e53b2f622fc80c646a3c4a19c5ee89 /mlir/lib/Dialect/GPU/Transforms
parent: eba6014cdc1cc1a9d9732a2e9010afde2d9d898e (diff)
download: bcm5719-llvm-74cdbf5909e57b42b6ed5b3b6eea4f76448a7d48.tar.gz
bcm5719-llvm-74cdbf5909e57b42b6ed5b3b6eea4f76448a7d48.zip
Clone called functions into nested GPU module.
PiperOrigin-RevId: 270891190
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms')
-rw-r--r-- mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp 38
1 files changed, 31 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 9bf4cf6e643..f38a2e81986 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -144,13 +144,10 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
ModuleManager moduleManager(getModule());
- auto context = getModule().getContext();
- Builder builder(context);
for (auto func : getModule().getOps<FuncOp>()) {
// Insert just after the function.
Block::iterator insertPt(func.getOperation()->getNextNode());
func.walk([&](gpu::LaunchOp op) {
- // TODO(b/141098412): Handle called functions and globals.
FuncOp outlinedFunc = outlineKernelFunc(op);
// Potentially renames outlinedFunc to make symbol unique.
@@ -164,14 +161,41 @@ public:
kernelFunc.getBody().takeBody(outlinedFunc.getBody());
// Create nested module and insert kernelFunc.
- auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
- kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
- builder.getUnitAttr());
- kernelModule.push_back(kernelFunc);
+ auto kernelModule = createKernelModule(kernelFunc, moduleManager);
getModule().insert(insertPt, kernelModule);
});
}
}
+
+private:
+  // Builds a new module holding `kernelFunc` plus a clone of every function
+  // reachable from it through `CallOp`s (transitively). Callee definitions
+  // are looked up in `parentModuleManager` and cloned rather than moved. The
+  // new module is tagged with the GPU dialect's kernel-module unit attribute.
+  // This function does not insert the returned module anywhere; placing it is
+  // left to the caller.
+  ModuleOp createKernelModule(FuncOp kernelFunc,
+                              const ModuleManager &parentModuleManager) {
+    auto context = getModule().getContext();
+    auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
+    kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
+                         UnitAttr::get(context));
+    ModuleManager moduleManager(kernelModule);
+    moduleManager.insert(kernelFunc);
+
+    // Worklist of functions that are already part of the kernel module but
+    // whose bodies have not yet been scanned for calls.
+    llvm::SmallVector<FuncOp, 8> funcsToScan = {kernelFunc};
+    while (!funcsToScan.empty()) {
+      FuncOp func = funcsToScan.pop_back_val();
+
+      // TODO(b/141098412): Support any op with a callable interface.
+      func.walk([&](CallOp call) {
+        auto callee = call.callee();
+        // Already present in the kernel module: nothing to clone.
+        if (moduleManager.lookupSymbol<FuncOp>(callee))
+          return;
+
+        // NOTE(review): assumes the callee is defined as a FuncOp in the
+        // parent module; a failed lookup is not handled here - confirm.
+        auto calleeFromParent =
+            parentModuleManager.lookupSymbol<FuncOp>(callee);
+        // Insert the clone immediately so the lookup above succeeds for any
+        // later call to the same callee. Deferring the insertion until the
+        // clone is popped from the worklist would clone a callee once per
+        // call site seen before that point.
+        FuncOp calleeClone = calleeFromParent.clone();
+        moduleManager.insert(calleeClone);
+        funcsToScan.push_back(calleeClone);
+      });
+    }
+
+    return kernelModule;
+  }
};
} // namespace
OpenPOWER on IntegriCloud