diff options
11 files changed, 165 insertions, 216 deletions
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 7fafb08aef2..9a15b41f7de 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -26,6 +26,10 @@ class OwningRewritePatternList; class ModuleOp; template <typename OpT> class OpPassBase; +/// Collect a set of patterns to convert from the GPU dialect to NVVM. +void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. std::unique_ptr<OpPassBase<ModuleOp>> createLowerGpuOpsToNVVMOpsPass(); diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h index e719dd4cbb4..d034212fc80 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -41,12 +41,9 @@ public: /// Get the canonical string name of the dialect. static StringRef getDialectName(); - /// Get the name of the attribute used to annotate external kernel functions. + /// Get the name of the attribute used to annotate outlined kernel functions. static StringRef getKernelFuncAttrName() { return "gpu.kernel"; } - /// Get the name of the attribute used to annotate kernel modules. - static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; } - /// Returns whether the given function is a kernel function, i.e., has the /// 'gpu.kernel' attribute. static bool isKernel(FuncOp function); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index aa1711e3f8e..a69fe81b0d3 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -49,37 +49,26 @@ namespace { // TODO(herhut): Move to shared location. 
static constexpr const char *kCubinAnnotation = "nvvm.cubin"; -/// A pass converting tagged kernel modules to cubin blobs. -/// -/// If tagged as a kernel module, each contained function is translated to NVVM -/// IR and further to PTX. A user provided CubinGenerator compiles the PTX to -/// GPU binary code, which is then attached as an attribute to the function. The -/// function body is erased. +/// A pass converting tagged kernel functions to cubin blobs. class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> { public: GpuKernelToCubinPass( CubinGenerator cubinGenerator = compilePtxToCubinForTesting) : cubinGenerator(cubinGenerator) {} + // Run the dialect converter on the module. void runOnModule() override { - if (!getModule().getAttrOfType<UnitAttr>( - gpu::GPUDialect::getKernelModuleAttrName())) - return; - // Make sure the NVPTX target is initialized. LLVMInitializeNVPTXTarget(); LLVMInitializeNVPTXTargetInfo(); LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXAsmPrinter(); - auto llvmModule = translateModuleToNVVMIR(getModule()); - if (!llvmModule) - return signalPassFailure(); - for (auto function : getModule().getOps<FuncOp>()) { - if (!gpu::GPUDialect::isKernel(function)) + if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) { continue; - if (failed(translateGpuKernelToCubinAnnotation(*llvmModule, function))) + } + if (failed(translateGpuKernelToCubinAnnotation(function))) signalPassFailure(); } } @@ -90,13 +79,8 @@ private: std::string translateModuleToPtx(llvm::Module &module, llvm::TargetMachine &target_machine); - - /// Converts llvmModule to cubin using the user-provded generator. OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, FuncOp &function); - - /// Translates llvmModule to cubin and assigns it to attribute of function. 
- LogicalResult translateGpuKernelToCubinAnnotation(llvm::Module &llvmModule, - FuncOp &function); + LogicalResult translateGpuKernelToCubinAnnotation(FuncOp &function); CubinGenerator cubinGenerator; }; @@ -151,13 +135,22 @@ OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule, return cubinGenerator(ptx, function); } -LogicalResult GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation( - llvm::Module &llvmModule, FuncOp &function) { - auto cubin = convertModuleToCubin(llvmModule, function); - if (!cubin) +LogicalResult +GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) { + Builder builder(function.getContext()); + + OwningModuleRef module = ModuleOp::create(function.getLoc()); + + // TODO(herhut): Also handle called functions. + module->push_back(function.clone()); + + auto llvmModule = translateModuleToNVVMIR(*module); + auto cubin = convertModuleToCubin(*llvmModule, function); + + if (!cubin) { return function.emitError("translation to CUDA binary failed."); + } - Builder builder(function.getContext()); function.setAttr(kCubinAnnotation, builder.getStringAttr({cubin->data(), cubin->size()})); diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index 83c3538324b..f8c6f5d15ff 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -43,15 +43,8 @@ constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter"; constexpr const char *kCubinGetterSuffix = "_cubin"; constexpr const char *kCubinStorageSuffix = "_cubin_cst"; -/// A pass which moves cubin from function attributes in nested modules -/// to global strings and generates getter functions. -/// -/// The GpuKernelToCubinPass annotates kernels functions with compiled device -/// code blobs. These functions reside in nested modules generated by -/// GpuKernelOutliningPass. 
This pass consumes these modules and moves the cubin -/// blobs back to the parent module as global strings and generates accessor -/// functions for them. The external kernel functions (also generated by the -/// outlining pass) are annotated with the symbol of the cubin accessor. +/// A pass generating global strings and getter functions for all cubin blobs +/// annotated on functions via the nvvm.cubin attribute. class GpuGenerateCubinAccessorsPass : public ModulePass<GpuGenerateCubinAccessorsPass> { private: @@ -62,25 +55,18 @@ private: } // Inserts a global constant string containing `blob` into the parent module - // of `kernelFunc` and generates the function that returns the address of the - // first character of this string. + // of `orig` and generates the function that returns the address of the first + // character of this string. // TODO(herhut): consider fusing this pass with launch-func-to-cuda. - void generate(FuncOp kernelFunc, StringAttr blob) { - auto stubFunc = getModule().lookupSymbol<FuncOp>(kernelFunc.getName()); - if (!stubFunc) { - kernelFunc.emitError( - "corresponding external function not found in parent module"); - return signalPassFailure(); - } - - Location loc = stubFunc.getLoc(); - SmallString<128> nameBuffer(stubFunc.getName()); - auto module = stubFunc.getParentOfType<ModuleOp>(); + void generate(FuncOp orig, StringAttr blob) { + Location loc = orig.getLoc(); + SmallString<128> nameBuffer(orig.getName()); + auto module = orig.getParentOfType<ModuleOp>(); assert(module && "function must belong to a module"); // Insert the getter function just after the original function. 
OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin()); - moduleBuilder.setInsertionPoint(stubFunc.getOperation()->getNextNode()); + moduleBuilder.setInsertionPoint(orig.getOperation()->getNextNode()); auto getterType = moduleBuilder.getFunctionType( llvm::None, LLVM::LLVMType::getInt8PtrTy(llvmDialect)); nameBuffer.append(kCubinGetterSuffix); @@ -89,7 +75,7 @@ private: Block *entryBlock = result.addEntryBlock(); // Drop the getter suffix before appending the storage suffix. - nameBuffer.resize(stubFunc.getName().size()); + nameBuffer.resize(orig.getName().size()); nameBuffer.append(kCubinStorageSuffix); // Obtain the address of the first character of the global string containing @@ -100,23 +86,21 @@ private: builder.create<LLVM::ReturnOp>(loc, startPtr); // Store the name of the getter on the function for easier lookup. - stubFunc.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result)); + orig.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result)); } public: + // Perform the conversion on the module. This may insert globals, so it + // cannot be done on multiple functions in parallel. 
void runOnModule() override { - llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); + llvmDialect = + getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); - auto modules = getModule().getOps<ModuleOp>(); - for (auto module : llvm::make_early_inc_range(modules)) { - if (!module.getAttrOfType<UnitAttr>( - gpu::GPUDialect::getKernelModuleAttrName())) + for (auto func : getModule().getOps<FuncOp>()) { + StringAttr cubinBlob = func.getAttrOfType<StringAttr>(kCubinAnnotation); + if (!cubinBlob) continue; - for (auto func : module.getOps<FuncOp>()) { - if (StringAttr blob = func.getAttrOfType<StringAttr>(kCubinAnnotation)) - generate(func, blob); - } - module.erase(); + generate(func, cubinBlob); } } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 0028badc4f4..1ae83ae9ae2 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -23,7 +23,6 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -39,6 +38,23 @@ using namespace mlir; namespace { +// Rewriting that replaces the types of a LaunchFunc operation with their +// LLVM counterparts. +struct GPULaunchFuncOpLowering : public LLVMOpLowering { +public: + explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_) + : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_) {} + + // Convert the kernel arguments to an LLVM type, preserve the rest. 
+  PatternMatchResult +  matchAndRewrite(Operation *op, ArrayRef<Value *> operands, +                  ConversionPatternRewriter &rewriter) const override { +    rewriter.clone(*op)->setOperands(operands); +    return rewriter.replaceOp(op, llvm::None), matchSuccess(); +  } +}; + // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension // that Op operates on. Op is assumed to return an `std.index` value and // XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on @@ -103,31 +119,20 @@ public: } }; -// A pass that replaces all occurences of GPU device operations with their +// A pass that replaces all occurrences of GPU operations with their // corresponding NVVM equivalent. // -// This pass only handles device code and is not meant to be run on GPU host -// code. +// This pass does not handle launching of kernels. Instead, it is meant to be +// used on the body region of a launch or the body region of a kernel +// function. class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> { public: void runOnModule() override { ModuleOp m = getModule(); - if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName())) - return; OwningRewritePatternList patterns; LLVMTypeConverter converter(m.getContext()); - populateStdToLLVMConversionPatterns(converter, patterns); - patterns.insert< - GPUIndexIntrinsicOpLowering<gpu::ThreadId, NVVM::ThreadIdXOp, - NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>, - GPUIndexIntrinsicOpLowering<gpu::BlockDim, NVVM::BlockDimXOp, - NVVM::BlockDimYOp, NVVM::BlockDimZOp>, - GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp, - NVVM::BlockIdYOp, NVVM::BlockIdZOp>, - GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp, - NVVM::GridDimYOp, NVVM::GridDimZOp>>( - converter); + populateGpuToNVVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); target.addLegalDialect<LLVM::LLVMDialect>(); @@ -141,6 +146,22 @@ public: } // anonymous namespace +/// Collect a set of patterns 
to convert from the GPU dialect to NVVM. +void mlir::populateGpuToNVVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns + .insert<GPULaunchFuncOpLowering, + GPUIndexIntrinsicOpLowering<gpu::ThreadId, NVVM::ThreadIdXOp, + NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>, + GPUIndexIntrinsicOpLowering<gpu::BlockDim, NVVM::BlockDimXOp, + NVVM::BlockDimYOp, NVVM::BlockDimZOp>, + GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp, + NVVM::BlockIdYOp, NVVM::BlockIdZOp>, + GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp, + NVVM::GridDimYOp, NVVM::GridDimZOp>>( + converter); +} + std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() { return std::make_unique<LowerGpuOpsToNVVMOpsPass>(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 9bf4cf6e643..4328fb39c29 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -93,7 +93,7 @@ static gpu::LaunchFuncOp inlineConstants(FuncOp kernelFunc, } // Outline the `gpu.launch` operation body into a kernel function. Replace -// `gpu.return` operations by `std.return` in the generated function. +// `gpu.return` operations by `std.return` in the generated functions. 
static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) { Location loc = launchOp.getLoc(); SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes()); @@ -107,7 +107,7 @@ static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) { outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr()); injectGpuIndexOperations(loc, outlinedFunc); - outlinedFunc.walk([](gpu::Return op) { + outlinedFunc.walk([](mlir::gpu::Return op) { OpBuilder replacer(op); replacer.create<ReturnOp>(op.getLoc()); op.erase(); @@ -131,44 +131,15 @@ static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) { namespace { -/// Pass that moves the kernel of each LaunchOp into its separate nested module. -/// -/// This pass moves the kernel code of each LaunchOp into a function created -/// inside a nested module. It also creates an external function of the same -/// name in the parent module. -/// -/// The kernel modules are intended to be compiled to a cubin blob independently -/// in a separate pass. The external functions can then be annotated with the -/// symbol of the cubin accessor function. class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> { public: void runOnModule() override { ModuleManager moduleManager(getModule()); - auto context = getModule().getContext(); - Builder builder(context); for (auto func : getModule().getOps<FuncOp>()) { - // Insert just after the function. - Block::iterator insertPt(func.getOperation()->getNextNode()); - func.walk([&](gpu::LaunchOp op) { - // TODO(b/141098412): Handle called functions and globals. + func.walk([&](mlir::gpu::LaunchOp op) { FuncOp outlinedFunc = outlineKernelFunc(op); - - // Potentially renames outlinedFunc to make symbol unique. - moduleManager.insert(insertPt, outlinedFunc); - - // Potentially changes signature, pulling in constants. 
+ moduleManager.insert(outlinedFunc); convertToLaunchFuncOp(op, outlinedFunc); - - // Create clone and move body from outlinedFunc. - auto kernelFunc = outlinedFunc.cloneWithoutRegions(); - kernelFunc.getBody().takeBody(outlinedFunc.getBody()); - - // Create nested module and insert kernelFunc. - auto kernelModule = ModuleOp::create(UnknownLoc::get(context)); - kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(), - builder.getUnitAttr()); - kernelModule.push_back(kernelFunc); - getModule().insert(insertPt, kernelModule); }); } } diff --git a/mlir/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir b/mlir/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir index 9e0907f7477..d2e291f57e7 100644 --- a/mlir/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir +++ b/mlir/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir @@ -2,14 +2,9 @@ // CHECK: llvm.mlir.global constant @[[global:.*]]("CUBIN") -module attributes {gpu.kernel_module} { - func @kernel(!llvm.float, !llvm<"float*">) - attributes {nvvm.cubin = "CUBIN"} -} - func @kernel(!llvm.float, !llvm<"float*">) -// CHECK: attributes {gpu.kernel, nvvm.cubingetter = @[[getter:.*]]} - attributes {gpu.kernel} +// CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN", nvvm.cubingetter = @[[getter:.*]]} + attributes {gpu.kernel, nvvm.cubin = "CUBIN"} // CHECK: func @[[getter]]() -> !llvm<"i8*"> // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]] diff --git a/mlir/test/Conversion/GPUToCUDA/lower-nvvm-kernel-to-cubin.mlir b/mlir/test/Conversion/GPUToCUDA/lower-nvvm-kernel-to-cubin.mlir index b6e19989203..8ddfc1996ef 100644 --- a/mlir/test/Conversion/GPUToCUDA/lower-nvvm-kernel-to-cubin.mlir +++ b/mlir/test/Conversion/GPUToCUDA/lower-nvvm-kernel-to-cubin.mlir @@ -1,26 +1,8 @@ -// RUN: mlir-opt %s --test-kernel-to-cubin -split-input-file | FileCheck %s - -module attributes {gpu.kernel_module} { - func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">) - // CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN"} 
- attributes { gpu.kernel } { - // CHECK-NOT: llvm.return - llvm.return - } -} - -// ----- - -module attributes {gpu.kernel_module} { - // CHECK: func @kernel_a - func @kernel_a() - attributes { gpu.kernel } { - llvm.return - } - - // CHECK: func @kernel_b - func @kernel_b() - attributes { gpu.kernel } { - llvm.return - } -} +// RUN: mlir-opt %s --test-kernel-to-cubin | FileCheck %s + +func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">) +// CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN"} + attributes { gpu.kernel } { +// CHECK-NOT: llvm.return + llvm.return +}
\ No newline at end of file diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index 02637376622..cf8e7ed1113 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1,37 +1,35 @@ // RUN: mlir-opt %s -lower-gpu-ops-to-nvvm-ops | FileCheck %s -module attributes {gpu.kernel_module} { - // CHECK-LABEL: func @gpu_index_ops() - func @gpu_index_ops() - attributes { gpu.kernel } { - // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32 - %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32 - %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32 - %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index) +// CHECK-LABEL: func @gpu_index_ops() +func @gpu_index_ops() + attributes { gpu.kernel } { + // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32 + %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32 + %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32 + %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32 - %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32 - %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32 - %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32 + %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32 + %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32 + %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index) - // CHECK: = 
nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 - %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 - %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 - %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 + %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 + %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 + %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 - %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 - %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index) - // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 - %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 + %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 + %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index) + // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 + %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) - std.return - } + std.return } diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir index fdfe8d08115..07499a305ee 100644 --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -gpu-kernel-outlining -split-input-file %s | FileCheck %s // CHECK-LABEL: func @launch() func @launch() { @@ -35,11 +35,7 @@ func @launch() { } // CHECK-LABEL: func @launch_kernel -// CHECK-SAME: (f32, memref<?xf32, 1>) -// CHECK-NEXT: attributes 
{gpu.kernel} - -// CHECK-LABEL: func @launch_kernel -// CHECK-SAME: (%[[KERNEL_ARG0:.*]]: f32, %[[KERNEL_ARG1:.*]]: memref<?xf32, 1>) +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: memref<?xf32, 1>) // CHECK-NEXT: attributes {gpu.kernel} // CHECK-NEXT: %[[BID:.*]] = "gpu.block_id"() {dimension = "x"} : () -> index // CHECK-NEXT: = "gpu.block_id"() {dimension = "y"} : () -> index @@ -53,9 +49,9 @@ func @launch() { // CHECK-NEXT: %[[BDIM:.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index // CHECK-NEXT: = "gpu.block_dim"() {dimension = "y"} : () -> index // CHECK-NEXT: = "gpu.block_dim"() {dimension = "z"} : () -> index -// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> () +// CHECK-NEXT: "use"(%[[ARG0]]) : (f32) -> () // CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> () -// CHECK-NEXT: = load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1> +// CHECK-NEXT: = load %[[ARG1]][%[[TID]]] : memref<?xf32, 1> // ----- @@ -79,8 +75,8 @@ func @multiple_launches() { return } -// CHECK: func @multiple_launches_kernel() -// CHECK: func @multiple_launches_kernel_0() +// CHECK-LABEL: func @multiple_launches_kernel() +// CHECK-LABEL: func @multiple_launches_kernel_0() // ----- @@ -104,23 +100,3 @@ func @extra_constants(%arg0 : memref<?xf32>) { // CHECK-LABEL: func @extra_constants_kernel(%{{.*}}: memref<?xf32>) // CHECK: constant // CHECK: constant - -// ----- - -func @function_call(%arg0 : memref<?xf32>) { - %cst = constant 8 : index - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, - %grid_z = %cst) - threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, - %block_z = %cst) { - // TODO(b/141098412): Support function calls. 
- // expected-error @+1 {{'device_function' does not reference a valid function}} - call @device_function() : () -> () - gpu.return - } - return -} - -func @device_function() { - gpu.return -} diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index 26bf3c58768..deddc63eb10 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -108,22 +108,50 @@ OwnedCubin compilePtxToCubin(const std::string ptx, FuncOp &function) { return result; } +namespace { +// A pass that lowers all Standard and Gpu operations to LLVM dialect. It does +// not lower the GPULaunch operation to actual code but does translate the +// signature of its kernel argument. +class LowerStandardAndGpuToLLVMAndNVVM + : public ModulePass<LowerStandardAndGpuToLLVMAndNVVM> { +public: + void runOnModule() override { + ModuleOp m = getModule(); + + OwningRewritePatternList patterns; + LLVMTypeConverter converter(m.getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + populateGpuToNVVMConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect>(); + target.addLegalDialect<NVVM::NVVMDialect>(); + target.addLegalOp<ModuleOp>(); + target.addLegalOp<ModuleTerminatorOp>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + if (failed(applyFullConversion(m, target, patterns, &converter))) + signalPassFailure(); + } +}; +} // end anonymous namespace + static LogicalResult runMLIRPasses(ModuleOp m) { PassManager pm(m.getContext()); - applyPassManagerCLOptions(pm); pm.addPass(createGpuKernelOutliningPass()); - auto &kernelPm = pm.nest<ModuleOp>(); - kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass()); - kernelPm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); - pm.addPass(createLowerToLLVMPass()); + 
pm.addPass(static_cast<std::unique_ptr<OpPassBase<ModuleOp>>>( + std::make_unique<LowerStandardAndGpuToLLVMAndNVVM>())); + pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); pm.addPass(createGenerateCubinAccessorPass()); pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass()); - return pm.run(m); + if (failed(pm.run(m))) + return failure(); + + return success(); } int main(int argc, char **argv) { - registerPassManagerCLOptions(); return mlir::JitRunnerMain(argc, argv, &runMLIRPasses); } |