summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp32
1 files changed, 17 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d9a1106270f..672beee56cd 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -182,21 +182,23 @@ private:
builder.getUnitAttr());
ModuleManager moduleManager(kernelModule);
- llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc};
- while (!funcsToInsert.empty()) {
- FuncOp func = funcsToInsert.pop_back_val();
- moduleManager.insert(func);
-
- // TODO(b/141098412): Support any op with a callable interface.
- func.walk([&](CallOp call) {
- auto callee = call.callee();
- if (moduleManager.lookupSymbol<FuncOp>(callee))
- return;
-
- auto calleeFromParent =
- parentModuleManager.lookupSymbol<FuncOp>(callee);
- funcsToInsert.push_back(calleeFromParent.clone());
- });
+ moduleManager.insert(kernelFunc);
+
+ llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
+ while (!symbolDefWorklist.empty()) {
+ if (Optional<SymbolTable::UseRange> symbolUses =
+ SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
+ for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ StringRef symbolName = symbolUse.getSymbolRef().getValue();
+ if (moduleManager.lookupSymbol(symbolName))
+ continue;
+
+ Operation *symbolDefClone =
+ parentModuleManager.lookupSymbol(symbolName)->clone();
+ symbolDefWorklist.push_back(symbolDefClone);
+ moduleManager.insert(symbolDefClone);
+ }
+ }
}
return kernelModule;
OpenPOWER on IntegriCloud