diff options
Diffstat (limited to 'clang/lib/CodeGen/CGStmtOpenMP.cpp')
-rw-r--r-- | clang/lib/CodeGen/CGStmtOpenMP.cpp | 46 |
1 files changed, 37 insertions, 9 deletions
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index 2d14ec40cbb..ba531b96ccf 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -1213,10 +1213,9 @@ static void emitCommonOMPParallelDirective(CodeGenFunction &CGF, const OMPExecutableDirective &S, OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) { - auto CS = cast<CapturedStmt>(S.getAssociatedStmt()); - auto OutlinedFn = CGF.CGM.getOpenMPRuntime(). - emitParallelOrTeamsOutlinedFunction(S, - *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen); + const CapturedStmt *CS = S.getCapturedStmt(OMPD_parallel); + auto OutlinedFn = CGF.CGM.getOpenMPRuntime().emitParallelOutlinedFunction( + S, *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen); if (const auto *NumThreadsClause = S.getSingleClause<OMPNumThreadsClause>()) { CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF); auto NumThreads = CGF.EmitScalarExpr(NumThreadsClause->getNumThreads(), @@ -3497,10 +3496,9 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF, const OMPExecutableDirective &S, OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) { - auto CS = cast<CapturedStmt>(S.getAssociatedStmt()); - auto OutlinedFn = CGF.CGM.getOpenMPRuntime(). - emitParallelOrTeamsOutlinedFunction(S, - *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen); + const CapturedStmt *CS = S.getCapturedStmt(OMPD_teams); + auto OutlinedFn = CGF.CGM.getOpenMPRuntime().emitTeamsOutlinedFunction( + S, *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen); const OMPTeamsDirective &TD = *dyn_cast<OMPTeamsDirective>(&S); const OMPNumTeamsClause *NT = TD.getSingleClause<OMPNumTeamsClause>(); @@ -3755,9 +3753,39 @@ void CodeGenFunction::EmitOMPTargetExitDataDirective( CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device); } +static void emitTargetParallelRegion(CodeGenFunction &CGF, + const OMPTargetParallelDirective &S, + PrePostActionTy &Action) { + // Get the captured statement associated with the 'parallel' region. + auto *CS = S.getCapturedStmt(OMPD_parallel); + Action.Enter(CGF); + auto &&CodeGen = [CS](CodeGenFunction &CGF, PrePostActionTy &) { + // TODO: Add support for clauses. + CGF.EmitStmt(CS->getCapturedStmt()); + }; + emitCommonOMPParallelDirective(CGF, S, OMPD_parallel, CodeGen); +} + +void CodeGenFunction::EmitOMPTargetParallelDeviceFunction( + CodeGenModule &CGM, StringRef ParentName, + const OMPTargetParallelDirective &S) { + auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) { + emitTargetParallelRegion(CGF, S, Action); + }; + llvm::Function *Fn; + llvm::Constant *Addr; + // Emit target region as a standalone region. + CGM.getOpenMPRuntime().emitTargetOutlinedFunction( + S, ParentName, Fn, Addr, /*IsOffloadEntry=*/true, CodeGen); + assert(Fn && Addr && "Target device function emission failed."); +} + void CodeGenFunction::EmitOMPTargetParallelDirective( const OMPTargetParallelDirective &S) { - // TODO: codegen for target parallel. + auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) { + emitTargetParallelRegion(CGF, S, Action); + }; + emitCommonOMPTargetDirective(*this, S, CodeGen); } void CodeGenFunction::EmitOMPTargetParallelForDirective( |