diff options
| author | Christian Sigg <csigg@google.com> | 2019-11-19 13:12:19 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-19 13:13:02 -0800 |
| commit | f868adafee91a8c3ebee1e052d5fdfff7be0afd0 (patch) | |
| tree | 34ef3b730148fa01a2979dac18c0dbead9ad9a86 /mlir/lib/Conversion/GPUToCUDA | |
| parent | ee95f6f2594e9089990024208d01634fd81d2da2 (diff) | |
| download | bcm5719-llvm-f868adafee91a8c3ebee1e052d5fdfff7be0afd0.tar.gz bcm5719-llvm-f868adafee91a8c3ebee1e052d5fdfff7be0afd0.zip | |
Make type and rank explicit in mcuMemHostRegister function.
Fix registered size of indirect MemRefType kernel arguments.
PiperOrigin-RevId: 281362940
Diffstat (limited to 'mlir/lib/Conversion/GPUToCUDA')
| -rw-r--r-- | mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index d9332428425..9d8c8942051 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -49,7 +49,7 @@ static constexpr const char *cuModuleGetFunctionName = "mcuModuleGetFunction"; static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel"; static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper"; static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize"; -static constexpr const char *kMcuMemHostRegisterPtr = "mcuMemHostRegisterPtr"; +static constexpr const char *kMcuMemHostRegister = "mcuMemHostRegister"; static constexpr const char *kCubinAnnotation = "nvvm.cubin"; static constexpr const char *kCubinStorageSuffix = "_cubin_cst"; @@ -228,13 +228,13 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { getPointerType() /* CUstream stream */, /*isVarArg=*/false)); } - if (!module.lookupSymbol(kMcuMemHostRegisterPtr)) { + if (!module.lookupSymbol(kMcuMemHostRegister)) { builder.create<LLVM::LLVMFuncOp>( - loc, kMcuMemHostRegisterPtr, + loc, kMcuMemHostRegister, LLVM::LLVMType::getFunctionTy(getVoidType(), { getPointerType(), /* void *ptr */ - getInt32Type() /* int32 flags*/ + getInt64Type() /* int64 sizeBytes*/ }, /*isVarArg=*/false)); } @@ -277,12 +277,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // the descriptor pointer is registered via @mcuMemHostRegisterPtr if (llvmType.isStructTy()) { auto registerFunc = - getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegisterPtr); - auto zero = builder.create<LLVM::ConstantOp>( - loc, getInt32Type(), builder.getI32IntegerAttr(0)); + getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister); + auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo()); + auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(), + ArrayRef<Value *>{nullPtr, one}); + auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep); builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, builder.getSymbolRefAttr(registerFunc), - ArrayRef<Value *>{casted, zero}); + ArrayRef<Value *>{casted, size}); Value *memLocation = builder.create<LLVM::AllocaOp>( loc, getPointerPointerType(), one, /*alignment=*/1); builder.create<LLVM::StoreOp>(loc, casted, memLocation); |

