summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/GPU/Transforms
diff options
context:
space:
mode:
authorMLIR Team <no-reply@google.com>2019-11-08 19:12:40 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-08 19:13:13 -0800
commit9fbf52e330faa9f310855d7e4a02d48c3a1ccd41 (patch)
treebbb2629681c67452ae0fa86c2eca35b80e09058d /mlir/lib/Dialect/GPU/Transforms
parentbcfb3d4cd6de3b535d7915972ac2af0b74378ff9 (diff)
downloadbcm5719-llvm-9fbf52e330faa9f310855d7e4a02d48c3a1ccd41.tar.gz
bcm5719-llvm-9fbf52e330faa9f310855d7e4a02d48c3a1ccd41.zip
Look for SymbolRefAttr in KernelOutlining instead of hard-coding CallOp
This code should be exercised using the existing kernel outlining unit test, but let me know if I should add a dedicated unit test using a fake call instruction as well. PiperOrigin-RevId: 279436321
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms')
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp32
1 files changed, 17 insertions, 15 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d9a1106270f..672beee56cd 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -182,21 +182,23 @@ private:
builder.getUnitAttr());
ModuleManager moduleManager(kernelModule);
- llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc};
- while (!funcsToInsert.empty()) {
- FuncOp func = funcsToInsert.pop_back_val();
- moduleManager.insert(func);
-
- // TODO(b/141098412): Support any op with a callable interface.
- func.walk([&](CallOp call) {
- auto callee = call.callee();
- if (moduleManager.lookupSymbol<FuncOp>(callee))
- return;
-
- auto calleeFromParent =
- parentModuleManager.lookupSymbol<FuncOp>(callee);
- funcsToInsert.push_back(calleeFromParent.clone());
- });
+ moduleManager.insert(kernelFunc);
+
+ llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
+ while (!symbolDefWorklist.empty()) {
+ if (Optional<SymbolTable::UseRange> symbolUses =
+ SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
+ for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ StringRef symbolName = symbolUse.getSymbolRef().getValue();
+ if (moduleManager.lookupSymbol(symbolName))
+ continue;
+
+ Operation *symbolDefClone =
+ parentModuleManager.lookupSymbol(symbolName)->clone();
+ symbolDefWorklist.push_back(symbolDefClone);
+ moduleManager.insert(symbolDefClone);
+ }
+ }
}
return kernelModule;
OpenPOWER on IntegriCloud