diff options
Diffstat (limited to 'clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp')
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp | 194 |
1 files changed, 171 insertions, 23 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp index 5a119003e5b..0b8d79e28b6 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -296,6 +296,7 @@ void CGOpenMPRuntimeNVPTX::emitGenericKernel(const OMPExecutableDirective &D, EntryFunctionState EST; WorkerFunctionState WST(CGM); Work.clear(); + WrapperFunctionsMap.clear(); // Emit target region as a standalone region. class NVPTXPrePostActionTy : public PrePostActionTy { @@ -468,7 +469,7 @@ static void setPropertyExecutionMode(CodeGenModule &CGM, StringRef Name, } void CGOpenMPRuntimeNVPTX::emitWorkerFunction(WorkerFunctionState &WST) { - auto &Ctx = CGM.getContext(); + ASTContext &Ctx = CGM.getContext(); CodeGenFunction CGF(CGM, /*suppressNewContext=*/true); CGF.disableDebugInfo(); @@ -511,7 +512,10 @@ void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF, CGF.InitTempAlloca(ExecStatus, Bld.getInt8(/*C=*/0)); CGF.InitTempAlloca(WorkFn, llvm::Constant::getNullValue(CGF.Int8PtrTy)); - llvm::Value *Args[] = {WorkFn.getPointer()}; + // Set up shared arguments + Address SharedArgs = + CGF.CreateDefaultAlignTempAlloca(CGF.Int8PtrPtrTy, "shared_args"); + llvm::Value *Args[] = {WorkFn.getPointer(), SharedArgs.getPointer()}; llvm::Value *Ret = CGF.EmitRuntimeCall( createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_parallel), Args); Bld.CreateStore(Bld.CreateZExt(Ret, CGF.Int8Ty), ExecStatus); @@ -530,6 +534,9 @@ void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF, // Signal start of parallel region. CGF.EmitBlock(ExecuteBB); + // Current context + ASTContext &Ctx = CGF.getContext(); + // Process work items: outlined parallel functions. for (auto *W : Work) { // Try to match this outlined function. @@ -545,14 +552,18 @@ void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF, // Execute this outlined function. CGF.EmitBlock(ExecuteFNBB); - // Insert call to work function. 
- // FIXME: Pass arguments to outlined function from master thread. - auto *Fn = cast<llvm::Function>(W); - Address ZeroAddr = - CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, /*Name=*/".zero.addr"); - CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C=*/0)); - llvm::Value *FnArgs[] = {ZeroAddr.getPointer(), ZeroAddr.getPointer()}; - emitCall(CGF, Fn, FnArgs); + // Insert call to work function via shared wrapper. The shared + // wrapper takes exactly three arguments: + // - the parallelism level; + // - the master thread ID; + // - the list of references to shared arguments. + // + // TODO: Assert that the function is a wrapper function.s + Address Capture = CGF.EmitLoadOfPointer(SharedArgs, + Ctx.getPointerType( + Ctx.getPointerType(Ctx.VoidPtrTy)).castAs<PointerType>()); + emitCall(CGF, W, {Bld.getInt16(/*ParallelLevel=*/0), + getMasterThreadID(CGF), Capture.getPointer()}); // Go to end of parallel region. CGF.EmitBranch(TerminateBB); @@ -618,16 +629,18 @@ CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) { } case OMPRTL_NVPTX__kmpc_kernel_prepare_parallel: { /// Build void __kmpc_kernel_prepare_parallel( - /// void *outlined_function); - llvm::Type *TypeParams[] = {CGM.Int8PtrTy}; + /// void *outlined_function, void ***args, kmp_int32 nArgs); + llvm::Type *TypeParams[] = {CGM.Int8PtrTy, + CGM.Int8PtrPtrTy->getPointerTo(0), CGM.Int32Ty}; llvm::FunctionType *FnTy = llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false); RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_prepare_parallel"); break; } case OMPRTL_NVPTX__kmpc_kernel_parallel: { - /// Build bool __kmpc_kernel_parallel(void **outlined_function); - llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy}; + /// Build bool __kmpc_kernel_parallel(void **outlined_function, void ***args); + llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy, + CGM.Int8PtrPtrTy->getPointerTo(0)}; llvm::Type *RetTy = CGM.getTypes().ConvertType(CGM.getContext().BoolTy); llvm::FunctionType *FnTy = 
llvm::FunctionType::get(RetTy, TypeParams, /*isVarArg*/ false); @@ -846,8 +859,17 @@ void CGOpenMPRuntimeNVPTX::emitNumTeamsClause(CodeGenFunction &CGF, llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction( const OMPExecutableDirective &D, const VarDecl *ThreadIDVar, OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) { - return CGOpenMPRuntime::emitParallelOutlinedFunction(D, ThreadIDVar, - InnermostKind, CodeGen); + + auto *OutlinedFun = cast<llvm::Function>( + CGOpenMPRuntime::emitParallelOutlinedFunction( + D, ThreadIDVar, InnermostKind, CodeGen)); + if (!isInSpmdExecutionMode()) { + llvm::Function *WrapperFun = + createDataSharingWrapper(OutlinedFun, D); + WrapperFunctionsMap[OutlinedFun] = WrapperFun; + } + + return OutlinedFun; } llvm::Value *CGOpenMPRuntimeNVPTX::emitTeamsOutlinedFunction( @@ -899,15 +921,52 @@ void CGOpenMPRuntimeNVPTX::emitGenericParallelCall( CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn, ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) { llvm::Function *Fn = cast<llvm::Function>(OutlinedFn); + llvm::Function *WFn = WrapperFunctionsMap[Fn]; + assert(WFn && "Wrapper function does not exist!"); + + // Force inline this outlined function at its call site. + Fn->setLinkage(llvm::GlobalValue::InternalLinkage); - auto &&L0ParallelGen = [this, Fn](CodeGenFunction &CGF, PrePostActionTy &) { + auto &&L0ParallelGen = [this, WFn, &CapturedVars](CodeGenFunction &CGF, + PrePostActionTy &) { CGBuilderTy &Bld = CGF.Builder; - // Prepare for parallel region. Indicate the outlined function. - llvm::Value *Args[] = {Bld.CreateBitOrPointerCast(Fn, CGM.Int8PtrTy)}; - CGF.EmitRuntimeCall( - createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel), - Args); + llvm::Value *ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy); + + if (!CapturedVars.empty()) { + // Prepare for parallel region. Indicate the outlined function. 
+ Address SharedArgs = + CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy, + "shared_args"); + llvm::Value *SharedArgsPtr = SharedArgs.getPointer(); + llvm::Value *Args[] = {ID, SharedArgsPtr, + Bld.getInt32(CapturedVars.size())}; + + CGF.EmitRuntimeCall( + createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel), + Args); + + unsigned Idx = 0; + ASTContext &Ctx = CGF.getContext(); + for (llvm::Value *V : CapturedVars) { + Address Dst = Bld.CreateConstInBoundsGEP( + CGF.EmitLoadOfPointer(SharedArgs, + Ctx.getPointerType( + Ctx.getPointerType(Ctx.VoidPtrTy)).castAs<PointerType>()), + Idx, CGF.getPointerSize()); + llvm::Value *PtrV = Bld.CreateBitCast(V, CGF.VoidPtrTy); + CGF.EmitStoreOfScalar(PtrV, Dst, /*Volatile=*/false, + Ctx.getPointerType(Ctx.VoidPtrTy)); + Idx++; + } + } else { + llvm::Value *Args[] = {ID, + llvm::ConstantPointerNull::get(CGF.VoidPtrPtrTy->getPointerTo(0)), + /*nArgs=*/Bld.getInt32(0)}; + CGF.EmitRuntimeCall( + createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel), + Args); + } // Activate workers. This barrier is used by the master to signal // work for the workers. @@ -922,7 +981,7 @@ void CGOpenMPRuntimeNVPTX::emitGenericParallelCall( syncCTAThreads(CGF); // Remember for post-processing in worker loop. - Work.push_back(Fn); + Work.emplace_back(WFn); }; auto *RTLoc = emitUpdateLocation(CGF, Loc); @@ -2318,3 +2377,92 @@ void CGOpenMPRuntimeNVPTX::emitOutlinedFunctionCall( } CGOpenMPRuntime::emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, TargetArgs); } + +/// Emit function which wraps the outline parallel region +/// and controls the arguments which are passed to this function. +/// The wrapper ensures that the outlined function is called +/// with the correct arguments when data is shared. 
+llvm::Function *CGOpenMPRuntimeNVPTX::createDataSharingWrapper(
+    llvm::Function *OutlinedParallelFn, const OMPExecutableDirective &D) {
+  ASTContext &Ctx = CGM.getContext();
+  const auto &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+
+  // Build the wrapper signature:
+  //   void <outlined>_wrapper(uint16 ParallelLevel, uint32 SourceThreadID,
+  //                           void **SharedArgs)
+  // The last parameter is the array of references to the shared arguments.
+  FunctionArgList WrapperArgs;
+  QualType Int16QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/16, /*Signed=*/false);
+  QualType Int32QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/false);
+  QualType VoidPtrPtrQTy = Ctx.getPointerType(Ctx.VoidPtrTy);
+  ImplicitParamDecl ParallelLevelArg(Ctx, Int16QTy, ImplicitParamDecl::Other);
+  ImplicitParamDecl WrapperArg(Ctx, Int32QTy, ImplicitParamDecl::Other);
+  ImplicitParamDecl SharedArgsList(Ctx, VoidPtrPtrQTy,
+                                   ImplicitParamDecl::Other);
+  WrapperArgs.emplace_back(&ParallelLevelArg);
+  WrapperArgs.emplace_back(&WrapperArg);
+  WrapperArgs.emplace_back(&SharedArgsList);
+
+  auto &CGFI =
+      CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, WrapperArgs);
+
+  // The wrapper is module-local and named after the function it wraps.
+  auto *Fn = llvm::Function::Create(
+      CGM.getTypes().GetFunctionType(CGFI), llvm::GlobalValue::InternalLinkage,
+      OutlinedParallelFn->getName() + "_wrapper", &CGM.getModule());
+  CGM.SetInternalFunctionAttributes(/*D=*/nullptr, Fn, CGFI);
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
+  CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, Fn, CGFI, WrapperArgs);
+
+  // Fields of the captured record are walked in lock-step with the captures.
+  const auto *RD = CS.getCapturedRecordDecl();
+  auto CurField = RD->field_begin();
+
+  // Arguments forwarded to the outlined function.
+  SmallVector<llvm::Value *, 8> Args;
+
+  // The outlined function's two leading i32* parameters are passed as null.
+  // TODO: support SIMD and pass actual values
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+
+  CGBuilderTy &Bld = CGF.Builder;
+  auto CI = CS.capture_begin();
+
+  // Load the start of the array of shared argument references.
+  auto SharedArgs =
+      CGF.EmitLoadOfPointer(CGF.GetAddrOfLocalVar(&SharedArgsList),
+                            VoidPtrPtrQTy->castAs<PointerType>());
+
+  // Forward one argument per capture. Only the capture kind and the record
+  // field type are needed; the captured declaration itself is deliberately
+  // not queried, because that accessor is only valid for variable captures.
+  for (unsigned I = 0, E = CS.capture_size(); I < E; ++I, ++CI, ++CurField) {
+    // We retrieve the CLANG type of the argument. We use it to derive the
+    // LLVM type of the slot we load from.
+    QualType ElemTy = CurField->getType();
+    // If this is a capture by copy the element type has to be the pointer to
+    // the data.
+    if (CI->capturesVariableByCopy())
+      ElemTy = Ctx.getPointerType(ElemTy);
+
+    // Slot I of the shared array holds the address of the captured variable.
+    QualType ElemPtrTy = Ctx.getPointerType(ElemTy);
+    Address ArgAddress = Bld.CreateConstInBoundsGEP(
+        SharedArgs, I, CGF.getPointerSize());
+    Address TypedArgAddress = Bld.CreateBitCast(
+        ArgAddress, CGF.ConvertTypeForMem(ElemPtrTy));
+    // Load with the same pointer type the slot address was just cast to.
+    llvm::Value *Arg = CGF.EmitLoadOfScalar(TypedArgAddress,
+        /*Volatile=*/false, ElemPtrTy, SourceLocation());
+    Args.emplace_back(Arg);
+  }
+
+  emitCall(CGF, OutlinedParallelFn, Args);
+  CGF.FinishFunction();
+  return Fn;
+} |