diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/IR/Module.h | 6 | ||||
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 32 | ||||
| -rw-r--r-- | mlir/test/Dialect/GPU/outlining.mlir | 18 |
3 files changed, 37 insertions, 19 deletions
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index de27a495377..9ac985ff586 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -135,6 +135,12 @@ public: return symbolTable.lookup<T>(name); } + /// Look up a symbol with the specified name, returning null if no such + /// name exists. Names must never include the @ on them. + template <typename NameTy> Operation *lookupSymbol(NameTy &&name) const { + return symbolTable.lookup(name); + } + /// Insert a new symbol into the module, auto-renaming it as necessary. void insert(Operation *op) { symbolTable.insert(op); diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d9a1106270f..672beee56cd 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -182,21 +182,23 @@ private: builder.getUnitAttr()); ModuleManager moduleManager(kernelModule); - llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc}; - while (!funcsToInsert.empty()) { - FuncOp func = funcsToInsert.pop_back_val(); - moduleManager.insert(func); - - // TODO(b/141098412): Support any op with a callable interface. - func.walk([&](CallOp call) { - auto callee = call.callee(); - if (moduleManager.lookupSymbol<FuncOp>(callee)) - return; - - auto calleeFromParent = - parentModuleManager.lookupSymbol<FuncOp>(callee); - funcsToInsert.push_back(calleeFromParent.clone()); - }); + moduleManager.insert(kernelFunc); + + llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (Optional<SymbolTable::UseRange> symbolUses = + SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = symbolUse.getSymbolRef().getValue(); + if (moduleManager.lookupSymbol(symbolName)) + continue; + + Operation *symbolDefClone = + parentModuleManager.lookupSymbol(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + moduleManager.insert(symbolDefClone); + } + } } return kernelModule; diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir index 8398907b6c0..d138cfe3236 100644 --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -111,6 +111,8 @@ func @extra_constants(%arg0 : memref<?xf32>) { // ----- +llvm.mlir.global @global(42 : i64) : !llvm.i64 + func @function_call(%arg0 : memref<?xf32>) { %cst = constant 8 : index gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, @@ -119,6 +121,7 @@ func @function_call(%arg0 : memref<?xf32>) { %block_z = %cst) { call @device_function() : () -> () call @device_function() : () -> () + %0 = llvm.mlir.addressof @global : !llvm<"i64*"> gpu.return } return @@ -134,7 +137,14 @@ func @recursive_device_function() { gpu.return } -// CHECK: @device_function -// CHECK: @recursive_device_function -// CHECK: @device_function -// CHECK: @recursive_device_function +// CHECK: module @function_call_kernel attributes {gpu.kernel_module} { +// CHECK: func @function_call_kernel() +// CHECK: call @device_function() : () -> () +// CHECK: call @device_function() : () -> () +// CHECK: llvm.mlir.addressof @global : !llvm<"i64*"> +// +// CHECK: llvm.mlir.global @global(42 : i64) : !llvm.i64 +// +// CHECK: func @device_function() +// CHECK: func @recursive_device_function() +// CHECK-NOT: func @device_function |

