diff options
Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntime.cpp | 22 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntime.h | 13 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGStmtOpenMP.cpp | 26 |
3 files changed, 54 insertions, 7 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index 53d4b6cb404..1a50c94bc7d 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -296,6 +296,16 @@ CGOpenMPRuntime::CreateRuntimeFunction(OpenMPRTLFunction Function) { RTLFn = CGM.CreateRuntimeFunction(FnTy, /*Name*/ "__kmpc_barrier"); break; } + case OMPRTL__kmpc_push_num_threads: { + // Build void __kmpc_push_num_threads(ident_t *loc, kmp_int32 global_tid, + // kmp_int32 num_threads) + llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty, + CGM.Int32Ty}; + llvm::FunctionType *FnTy = + llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false); + RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_push_num_threads"); + break; + } case OMPRTL__kmpc_serialized_parallel: { // Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32 // global_tid); @@ -431,3 +441,15 @@ void CGOpenMPRuntime::EmitOMPBarrierCall(CodeGenFunction &CGF, CGF.EmitRuntimeCall(RTLFn, Args); } +void CGOpenMPRuntime::EmitOMPNumThreadsClause(CodeGenFunction &CGF, + llvm::Value *NumThreads, + SourceLocation Loc) { + // Build call __kmpc_push_num_threads(&loc, global_tid, num_threads) + llvm::Value *Args[] = { + EmitOpenMPUpdateLocation(CGF, Loc), GetOpenMPThreadID(CGF, Loc), + CGF.Builder.CreateIntCast(NumThreads, CGF.Int32Ty, /*isSigned*/ true)}; + llvm::Constant *RTLFn = CGF.CGM.getOpenMPRuntime().CreateRuntimeFunction( + CGOpenMPRuntime::OMPRTL__kmpc_push_num_threads); + CGF.EmitRuntimeCall(RTLFn, Args); +} + diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h index 04378821d7a..c4da375acf5 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.h +++ b/clang/lib/CodeGen/CGOpenMPRuntime.h @@ -80,7 +80,10 @@ public: OMPRTL__kmpc_serialized_parallel, // Call to void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32 // global_tid); - OMPRTL__kmpc_end_serialized_parallel + OMPRTL__kmpc_end_serialized_parallel, + // Call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32 global_tid, + // kmp_int32 num_threads); + OMPRTL__kmpc_push_num_threads }; private: @@ -250,6 +253,14 @@ public: /// virtual void EmitOMPBarrierCall(CodeGenFunction &CGF, SourceLocation Loc, OpenMPLocationFlags Flags); + + /// \brief Emits call to void __kmpc_push_num_threads(ident_t *loc, kmp_int32 + /// global_tid, kmp_int32 num_threads) to generate code for 'num_threads' + /// clause. + /// \param NumThreads An integer value of threads. + virtual void EmitOMPNumThreadsClause(CodeGenFunction &CGF, + llvm::Value *NumThreads, + SourceLocation Loc); }; } // namespace CodeGen } // namespace clang diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index a459d07a723..2b0fb043be0 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -183,6 +183,23 @@ void CodeGenFunction::EmitOMPFirstprivateClause( } } +/// \brief Emits code for OpenMP parallel directive in the parallel region. +static void EmitOMPParallelCall(CodeGenFunction &CGF, + const OMPParallelDirective &S, + llvm::Value *OutlinedFn, + llvm::Value *CapturedStruct) { + if (auto C = S.getSingleClause(/*K*/ OMPC_num_threads)) { + CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF); + auto NumThreadsClause = cast<OMPNumThreadsClause>(C); + auto NumThreads = CGF.EmitScalarExpr(NumThreadsClause->getNumThreads(), + /*IgnoreResultAssign*/ true); + CGF.CGM.getOpenMPRuntime().EmitOMPNumThreadsClause( + CGF, NumThreads, NumThreadsClause->getLocStart()); + } + CGF.CGM.getOpenMPRuntime().EmitOMPParallelCall(CGF, S.getLocStart(), + OutlinedFn, CapturedStruct); +} + void CodeGenFunction::EmitOMPParallelDirective(const OMPParallelDirective &S) { auto CS = cast<CapturedStmt>(S.getAssociatedStmt()); auto CapturedStruct = GenerateCapturedStmtArgument(*CS); @@ -192,16 +209,13 @@ void CodeGenFunction::EmitOMPParallelDirective(const OMPParallelDirective &S) { auto Cond = cast<OMPIfClause>(C)->getCondition(); EmitOMPIfClause(*this, Cond, [&](bool ThenBlock) { if (ThenBlock) - CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(), - OutlinedFn, CapturedStruct); + EmitOMPParallelCall(*this, S, OutlinedFn, CapturedStruct); else CGM.getOpenMPRuntime().EmitOMPSerialCall(*this, S.getLocStart(), OutlinedFn, CapturedStruct); }); - } else { - CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(), - OutlinedFn, CapturedStruct); - } + } else + EmitOMPParallelCall(*this, S, OutlinedFn, CapturedStruct); } void CodeGenFunction::EmitOMPLoopBody(const OMPLoopDirective &S, |