diff options
| author | MLIR Team <no-reply@google.com> | 2019-11-08 19:12:40 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-08 19:13:13 -0800 |
| commit | 9fbf52e330faa9f310855d7e4a02d48c3a1ccd41 (patch) | |
| tree | bbb2629681c67452ae0fa86c2eca35b80e09058d /mlir/lib/Dialect/GPU/Transforms | |
| parent | bcfb3d4cd6de3b535d7915972ac2af0b74378ff9 (diff) | |
| download | bcm5719-llvm-9fbf52e330faa9f310855d7e4a02d48c3a1ccd41.tar.gz bcm5719-llvm-9fbf52e330faa9f310855d7e4a02d48c3a1ccd41.zip | |
Look for SymbolRefAttr in KernelOutlining instead of hard-coding CallOp
This code should be exercised using the existing kernel outlining unit test, but
let me know if I should add a dedicated unit test using a fake call instruction
as well.
PiperOrigin-RevId: 279436321
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms')
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d9a1106270f..672beee56cd 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -182,21 +182,23 @@ private: builder.getUnitAttr()); ModuleManager moduleManager(kernelModule); - llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc}; - while (!funcsToInsert.empty()) { - FuncOp func = funcsToInsert.pop_back_val(); - moduleManager.insert(func); - - // TODO(b/141098412): Support any op with a callable interface. - func.walk([&](CallOp call) { - auto callee = call.callee(); - if (moduleManager.lookupSymbol<FuncOp>(callee)) - return; - - auto calleeFromParent = - parentModuleManager.lookupSymbol<FuncOp>(callee); - funcsToInsert.push_back(calleeFromParent.clone()); - }); + moduleManager.insert(kernelFunc); + + llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (Optional<SymbolTable::UseRange> symbolUses = + SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = symbolUse.getSymbolRef().getValue(); + if (moduleManager.lookupSymbol(symbolName)) + continue; + + Operation *symbolDefClone = + parentModuleManager.lookupSymbol(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + moduleManager.insert(symbolDefClone); + } + } } return kernelModule; |

