summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/IR/Module.h6
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp32
-rw-r--r--mlir/test/Dialect/GPU/outlining.mlir18
3 files changed, 37 insertions, 19 deletions
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
index de27a495377..9ac985ff586 100644
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -135,6 +135,12 @@ public:
return symbolTable.lookup<T>(name);
}
+ /// Look up a symbol with the specified name, returning null if no such
+ /// name exists. Names must never include the @ on them.
+ template <typename NameTy> Operation *lookupSymbol(NameTy &&name) const {
+ return symbolTable.lookup(name);
+ }
+
/// Insert a new symbol into the module, auto-renaming it as necessary.
void insert(Operation *op) {
symbolTable.insert(op);
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d9a1106270f..672beee56cd 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -182,21 +182,23 @@ private:
builder.getUnitAttr());
ModuleManager moduleManager(kernelModule);
- llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc};
- while (!funcsToInsert.empty()) {
- FuncOp func = funcsToInsert.pop_back_val();
- moduleManager.insert(func);
-
- // TODO(b/141098412): Support any op with a callable interface.
- func.walk([&](CallOp call) {
- auto callee = call.callee();
- if (moduleManager.lookupSymbol<FuncOp>(callee))
- return;
-
- auto calleeFromParent =
- parentModuleManager.lookupSymbol<FuncOp>(callee);
- funcsToInsert.push_back(calleeFromParent.clone());
- });
+ moduleManager.insert(kernelFunc);
+
+ llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
+ while (!symbolDefWorklist.empty()) {
+ if (Optional<SymbolTable::UseRange> symbolUses =
+ SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
+ for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ StringRef symbolName = symbolUse.getSymbolRef().getValue();
+ if (moduleManager.lookupSymbol(symbolName))
+ continue;
+
+ Operation *symbolDefClone =
+ parentModuleManager.lookupSymbol(symbolName)->clone();
+ symbolDefWorklist.push_back(symbolDefClone);
+ moduleManager.insert(symbolDefClone);
+ }
+ }
}
return kernelModule;
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 8398907b6c0..d138cfe3236 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -111,6 +111,8 @@ func @extra_constants(%arg0 : memref<?xf32>) {
// -----
+llvm.mlir.global @global(42 : i64) : !llvm.i64
+
func @function_call(%arg0 : memref<?xf32>) {
%cst = constant 8 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
@@ -119,6 +121,7 @@ func @function_call(%arg0 : memref<?xf32>) {
%block_z = %cst) {
call @device_function() : () -> ()
call @device_function() : () -> ()
+ %0 = llvm.mlir.addressof @global : !llvm<"i64*">
gpu.return
}
return
@@ -134,7 +137,14 @@ func @recursive_device_function() {
gpu.return
}
-// CHECK: @device_function
-// CHECK: @recursive_device_function
-// CHECK: @device_function
-// CHECK: @recursive_device_function
+// CHECK: module @function_call_kernel attributes {gpu.kernel_module} {
+// CHECK: func @function_call_kernel()
+// CHECK: call @device_function() : () -> ()
+// CHECK: call @device_function() : () -> ()
+// CHECK: llvm.mlir.addressof @global : !llvm<"i64*">
+//
+// CHECK: llvm.mlir.global @global(42 : i64) : !llvm.i64
+//
+// CHECK: func @device_function()
+// CHECK: func @recursive_device_function()
+// CHECK-NOT: func @device_function
OpenPOWER on IntegriCloud