diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d9a1106270f..672beee56cd 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -182,21 +182,23 @@ private: builder.getUnitAttr()); ModuleManager moduleManager(kernelModule); - llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc}; - while (!funcsToInsert.empty()) { - FuncOp func = funcsToInsert.pop_back_val(); - moduleManager.insert(func); - - // TODO(b/141098412): Support any op with a callable interface. - func.walk([&](CallOp call) { - auto callee = call.callee(); - if (moduleManager.lookupSymbol<FuncOp>(callee)) - return; - - auto calleeFromParent = - parentModuleManager.lookupSymbol<FuncOp>(callee); - funcsToInsert.push_back(calleeFromParent.clone()); - }); + moduleManager.insert(kernelFunc); + + llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (Optional<SymbolTable::UseRange> symbolUses = + SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = symbolUse.getSymbolRef().getValue(); + if (moduleManager.lookupSymbol(symbolName)) + continue; + + Operation *symbolDefClone = + parentModuleManager.lookupSymbol(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + moduleManager.insert(symbolDefClone); + } + } } return kernelModule; |

