diff options
Diffstat (limited to 'clang/lib')
-rw-r--r-- | clang/lib/AST/StmtOpenMP.cpp | 33 | ||||
-rw-r--r-- | clang/lib/AST/StmtPrinter.cpp | 14 | ||||
-rw-r--r-- | clang/lib/Basic/OpenMPKinds.cpp | 8 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGStmtOpenMP.cpp | 27 | ||||
-rw-r--r-- | clang/lib/Parse/ParseOpenMP.cpp | 9 | ||||
-rw-r--r-- | clang/lib/Sema/SemaOpenMP.cpp | 119 |
6 files changed, 167 insertions, 43 deletions
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index b4e57952110..af90ea531df 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -795,13 +795,14 @@ OMPTargetDataDirective *OMPTargetDataDirective::CreateEmpty(const ASTContext &C, OMPTargetEnterDataDirective *OMPTargetEnterDataDirective::Create( const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, - ArrayRef<OMPClause *> Clauses) { + ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt) { void *Mem = C.Allocate( llvm::alignTo(sizeof(OMPTargetEnterDataDirective), alignof(OMPClause *)) + - sizeof(OMPClause *) * Clauses.size()); + sizeof(OMPClause *) * Clauses.size() + sizeof(Stmt *)); OMPTargetEnterDataDirective *Dir = new (Mem) OMPTargetEnterDataDirective(StartLoc, EndLoc, Clauses.size()); Dir->setClauses(Clauses); + Dir->setAssociatedStmt(AssociatedStmt); return Dir; } @@ -810,20 +811,20 @@ OMPTargetEnterDataDirective::CreateEmpty(const ASTContext &C, unsigned N, EmptyShell) { void *Mem = C.Allocate( llvm::alignTo(sizeof(OMPTargetEnterDataDirective), alignof(OMPClause *)) + - sizeof(OMPClause *) * N); + sizeof(OMPClause *) * N + sizeof(Stmt *)); return new (Mem) OMPTargetEnterDataDirective(N); } -OMPTargetExitDataDirective * -OMPTargetExitDataDirective::Create(const ASTContext &C, SourceLocation StartLoc, - SourceLocation EndLoc, - ArrayRef<OMPClause *> Clauses) { +OMPTargetExitDataDirective *OMPTargetExitDataDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt) { void *Mem = C.Allocate( llvm::alignTo(sizeof(OMPTargetExitDataDirective), alignof(OMPClause *)) + - sizeof(OMPClause *) * Clauses.size()); + sizeof(OMPClause *) * Clauses.size() + sizeof(Stmt *)); OMPTargetExitDataDirective *Dir = new (Mem) OMPTargetExitDataDirective(StartLoc, EndLoc, Clauses.size()); Dir->setClauses(Clauses); + Dir->setAssociatedStmt(AssociatedStmt); return Dir; } @@ -832,7 +833,7 @@ OMPTargetExitDataDirective::CreateEmpty(const ASTContext &C, unsigned N, EmptyShell) { void *Mem = C.Allocate( llvm::alignTo(sizeof(OMPTargetExitDataDirective), alignof(OMPClause *)) + - sizeof(OMPClause *) * N); + sizeof(OMPClause *) * N + sizeof(Stmt *)); return new (Mem) OMPTargetExitDataDirective(N); } @@ -1007,16 +1008,17 @@ OMPDistributeDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses, return new (Mem) OMPDistributeDirective(CollapsedNum, NumClauses); } -OMPTargetUpdateDirective * -OMPTargetUpdateDirective::Create(const ASTContext &C, SourceLocation StartLoc, - SourceLocation EndLoc, - ArrayRef<OMPClause *> Clauses) { +OMPTargetUpdateDirective *OMPTargetUpdateDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt) { unsigned Size = llvm::alignTo(sizeof(OMPTargetUpdateDirective), alignof(OMPClause *)); - void *Mem = C.Allocate(Size + sizeof(OMPClause *) * Clauses.size()); + void *Mem = + C.Allocate(Size + sizeof(OMPClause *) * Clauses.size() + sizeof(Stmt *)); OMPTargetUpdateDirective *Dir = new (Mem) OMPTargetUpdateDirective(StartLoc, EndLoc, Clauses.size()); Dir->setClauses(Clauses); + Dir->setAssociatedStmt(AssociatedStmt); return Dir; } @@ -1025,7 +1027,8 @@ OMPTargetUpdateDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses, EmptyShell) { unsigned Size = llvm::alignTo(sizeof(OMPTargetUpdateDirective), alignof(OMPClause *)); - void *Mem = C.Allocate(Size + sizeof(OMPClause *) * NumClauses); + void *Mem = + C.Allocate(Size + sizeof(OMPClause *) * NumClauses + sizeof(Stmt *)); return new (Mem) OMPTargetUpdateDirective(NumClauses); } diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index 09092743f0d..b367519425d 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -75,7 +75,8 @@ namespace { void PrintCallArgs(CallExpr *E); void PrintRawSEHExceptHandler(SEHExceptStmt *S); void PrintRawSEHFinallyStmt(SEHFinallyStmt *S); - void PrintOMPExecutableDirective(OMPExecutableDirective *S); + void PrintOMPExecutableDirective(OMPExecutableDirective *S, + bool ForceNoStmt = false); void PrintExpr(Expr *E) { if (E) @@ -1022,7 +1023,8 @@ void OMPClausePrinter::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *Node) { // OpenMP directives printing methods //===----------------------------------------------------------------------===// -void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S) { +void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S, + bool ForceNoStmt) { OMPClausePrinter Printer(OS, Policy); ArrayRef<OMPClause *> Clauses = S->clauses(); for (ArrayRef<OMPClause *>::iterator I = Clauses.begin(), E = Clauses.end(); @@ -1032,7 +1034,7 @@ void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S) { OS << ' '; } OS << "\n"; - if (S->hasAssociatedStmt() && S->getAssociatedStmt()) { + if (S->hasAssociatedStmt() && S->getAssociatedStmt() && !ForceNoStmt) { assert(isa<CapturedStmt>(S->getAssociatedStmt()) && "Expected captured statement!"); Stmt *CS = cast<CapturedStmt>(S->getAssociatedStmt())->getCapturedStmt(); @@ -1161,13 +1163,13 @@ void StmtPrinter::VisitOMPTargetDataDirective(OMPTargetDataDirective *Node) { void StmtPrinter::VisitOMPTargetEnterDataDirective( OMPTargetEnterDataDirective *Node) { Indent() << "#pragma omp target enter data "; - PrintOMPExecutableDirective(Node); + PrintOMPExecutableDirective(Node, /*ForceNoStmt=*/true); } void StmtPrinter::VisitOMPTargetExitDataDirective( OMPTargetExitDataDirective *Node) { Indent() << "#pragma omp target exit data "; - PrintOMPExecutableDirective(Node); + PrintOMPExecutableDirective(Node, /*ForceNoStmt=*/true); } void StmtPrinter::VisitOMPTargetParallelDirective( @@ -1219,7 +1221,7 @@ void StmtPrinter::VisitOMPDistributeDirective(OMPDistributeDirective *Node) { void StmtPrinter::VisitOMPTargetUpdateDirective( OMPTargetUpdateDirective *Node) { Indent() << "#pragma omp target update "; - PrintOMPExecutableDirective(Node); + PrintOMPExecutableDirective(Node, /*ForceNoStmt=*/true); } void StmtPrinter::VisitOMPDistributeParallelForDirective( diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp index 3dbcd4cdfbd..6cea8f4597c 100644 --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -935,6 +935,11 @@ void clang::getOpenMPCaptureRegions( CaptureRegions.push_back(OMPD_target); CaptureRegions.push_back(OMPD_parallel); break; + case OMPD_target_enter_data: + case OMPD_target_exit_data: + case OMPD_target_update: + CaptureRegions.push_back(OMPD_task); + break; case OMPD_threadprivate: case OMPD_taskyield: case OMPD_barrier: @@ -942,13 +947,10 @@ void clang::getOpenMPCaptureRegions( case OMPD_cancellation_point: case OMPD_cancel: case OMPD_flush: - case OMPD_target_enter_data: - case OMPD_target_exit_data: case OMPD_declare_reduction: case OMPD_declare_simd: case OMPD_declare_target: case OMPD_end_declare_target: - case OMPD_target_update: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index c8027e67a78..0c3d15e092c 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -4085,7 +4085,14 @@ void CodeGenFunction::EmitOMPTargetEnterDataDirective( if (auto *C = S.getSingleClause<OMPDeviceClause>()) Device = C->getDevice(); - CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device); + auto &&CodeGen = [&S, IfCond, Device](CodeGenFunction &CGF, + PrePostActionTy &) { + CGF.CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(CGF, S, IfCond, + Device); + }; + OMPLexicalScope Scope(*this, S, /*AsInlined=*/true); + CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_target_enter_data, + CodeGen); } void CodeGenFunction::EmitOMPTargetExitDataDirective( @@ -4105,7 +4112,14 @@ void CodeGenFunction::EmitOMPTargetExitDataDirective( if (auto *C = S.getSingleClause<OMPDeviceClause>()) Device = C->getDevice(); - CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device); + auto &&CodeGen = [&S, IfCond, Device](CodeGenFunction &CGF, + PrePostActionTy &) { + CGF.CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(CGF, S, IfCond, + Device); + }; + OMPLexicalScope Scope(*this, S, /*AsInlined=*/true); + CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_target_exit_data, + CodeGen); } static void emitTargetParallelRegion(CodeGenFunction &CGF, @@ -4404,5 +4418,12 @@ void CodeGenFunction::EmitOMPTargetUpdateDirective( if (auto *C = S.getSingleClause<OMPDeviceClause>()) Device = C->getDevice(); - CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device); + auto &&CodeGen = [&S, IfCond, Device](CodeGenFunction &CGF, + PrePostActionTy &) { + CGF.CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(CGF, S, IfCond, + Device); + }; + OMPLexicalScope Scope(*this, S, /*AsInlined=*/true); + CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_target_update, + CodeGen); } diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index e1685f6a9db..a67a5bbe0de 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -1086,6 +1086,15 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective( AssociatedStmt = ParseStatement(); Actions.ActOnFinishOfCompoundStmt(); AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); + } else if (DKind == OMPD_target_update || DKind == OMPD_target_enter_data || + DKind == OMPD_target_exit_data) { + Sema::CompoundScopeRAII CompoundScope(Actions); + Actions.ActOnOpenMPRegionStart(DKind, getCurScope()); + Actions.ActOnStartOfCompoundStmt(); + AssociatedStmt = + Actions.ActOnCompoundStmt(Loc, Loc, llvm::None, /*isStmtExpr=*/false); + Actions.ActOnFinishOfCompoundStmt(); + AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses); } Directive = Actions.ActOnOpenMPExecutableDirective( DKind, DirName, CancelRegion, Clauses, AssociatedStmt.get(), Loc, diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 320fcbcb455..771cd8452b0 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -2247,6 +2247,32 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { ParamsParallel); break; } + case OMPD_target_update: + case OMPD_target_enter_data: + case OMPD_target_exit_data: { + QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1); + QualType Args[] = {Context.VoidPtrTy.withConst().withRestrict()}; + FunctionProtoType::ExtProtoInfo EPI; + EPI.Variadic = true; + QualType CopyFnType = Context.getFunctionType(Context.VoidTy, Args, EPI); + Sema::CapturedParamNameType Params[] = { + std::make_pair(".global_tid.", KmpInt32Ty), + std::make_pair(".part_id.", Context.getPointerType(KmpInt32Ty)), + std::make_pair(".privates.", Context.VoidPtrTy.withConst()), + std::make_pair(".copy_fn.", + Context.getPointerType(CopyFnType).withConst()), + std::make_pair(".task_t.", Context.VoidPtrTy.withConst()), + std::make_pair(StringRef(), QualType()) // __context with shared vars + }; + ActOnCapturedRegionStart(DSAStack->getConstructLoc(), CurScope, CR_OpenMP, + Params); + // Mark this captured region as inlined, because we don't use outlined + // function directly. + getCurCapturedRegion()->TheCapturedDecl->addAttr( + AlwaysInlineAttr::CreateImplicit( + Context, AlwaysInlineAttr::Keyword_forceinline, SourceRange())); + break; + } case OMPD_threadprivate: case OMPD_taskyield: case OMPD_barrier: @@ -2254,13 +2280,10 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) { case OMPD_cancellation_point: case OMPD_cancel: case OMPD_flush: - case OMPD_target_enter_data: - case OMPD_target_exit_data: case OMPD_declare_reduction: case OMPD_declare_simd: case OMPD_declare_target: case OMPD_end_declare_target: - case OMPD_target_update: llvm_unreachable("OpenMP Directive is not allowed"); case OMPD_unknown: llvm_unreachable("Unknown OpenMP directive"); @@ -2993,12 +3016,12 @@ StmtResult Sema::ActOnOpenMPExecutableDirective( break; case OMPD_target_enter_data: Res = ActOnOpenMPTargetEnterDataDirective(ClausesWithImplicit, StartLoc, - EndLoc); + EndLoc, AStmt); AllowedNameModifiers.push_back(OMPD_target_enter_data); break; case OMPD_target_exit_data: Res = ActOnOpenMPTargetExitDataDirective(ClausesWithImplicit, StartLoc, - EndLoc); + EndLoc, AStmt); AllowedNameModifiers.push_back(OMPD_target_exit_data); break; case OMPD_taskloop: @@ -3016,9 +3039,8 @@ StmtResult Sema::ActOnOpenMPExecutableDirective( EndLoc, VarsWithInheritedDSA); break; case OMPD_target_update: - assert(!AStmt && "Statement is not allowed for target update"); - Res = - ActOnOpenMPTargetUpdateDirective(ClausesWithImplicit, StartLoc, EndLoc); + Res = ActOnOpenMPTargetUpdateDirective(ClausesWithImplicit, StartLoc, + EndLoc, AStmt); AllowedNameModifiers.push_back(OMPD_target_update); break; case OMPD_distribute_parallel_for: @@ -6423,7 +6445,28 @@ StmtResult Sema::ActOnOpenMPTargetDataDirective(ArrayRef<OMPClause *> Clauses, StmtResult Sema::ActOnOpenMPTargetEnterDataDirective(ArrayRef<OMPClause *> Clauses, SourceLocation StartLoc, - SourceLocation EndLoc) { + SourceLocation EndLoc, Stmt *AStmt) { + if (!AStmt) + return StmtError(); + + CapturedStmt *CS = cast<CapturedStmt>(AStmt); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + for (int ThisCaptureLevel = getOpenMPCaptureLevels(OMPD_target_enter_data); + ThisCaptureLevel > 1; --ThisCaptureLevel) { + CS = cast<CapturedStmt>(CS->getCapturedStmt()); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + } + // OpenMP [2.10.2, Restrictions, p. 99] // At least one map clause must appear on the directive. if (!hasClauses(Clauses, OMPC_map)) { @@ -6432,14 +6475,35 @@ Sema::ActOnOpenMPTargetEnterDataDirective(ArrayRef<OMPClause *> Clauses, return StmtError(); } - return OMPTargetEnterDataDirective::Create(Context, StartLoc, EndLoc, - Clauses); + return OMPTargetEnterDataDirective::Create(Context, StartLoc, EndLoc, Clauses, + AStmt); } StmtResult Sema::ActOnOpenMPTargetExitDataDirective(ArrayRef<OMPClause *> Clauses, SourceLocation StartLoc, - SourceLocation EndLoc) { + SourceLocation EndLoc, Stmt *AStmt) { + if (!AStmt) + return StmtError(); + + CapturedStmt *CS = cast<CapturedStmt>(AStmt); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + for (int ThisCaptureLevel = getOpenMPCaptureLevels(OMPD_target_exit_data); + ThisCaptureLevel > 1; --ThisCaptureLevel) { + CS = cast<CapturedStmt>(CS->getCapturedStmt()); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + } + // OpenMP [2.10.3, Restrictions, p. 102] // At least one map clause must appear on the directive. if (!hasClauses(Clauses, OMPC_map)) { @@ -6448,17 +6512,41 @@ Sema::ActOnOpenMPTargetExitDataDirective(ArrayRef<OMPClause *> Clauses, return StmtError(); } - return OMPTargetExitDataDirective::Create(Context, StartLoc, EndLoc, Clauses); + return OMPTargetExitDataDirective::Create(Context, StartLoc, EndLoc, Clauses, + AStmt); } StmtResult Sema::ActOnOpenMPTargetUpdateDirective(ArrayRef<OMPClause *> Clauses, SourceLocation StartLoc, - SourceLocation EndLoc) { + SourceLocation EndLoc, + Stmt *AStmt) { + if (!AStmt) + return StmtError(); + + CapturedStmt *CS = cast<CapturedStmt>(AStmt); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + for (int ThisCaptureLevel = getOpenMPCaptureLevels(OMPD_target_update); + ThisCaptureLevel > 1; --ThisCaptureLevel) { + CS = cast<CapturedStmt>(CS->getCapturedStmt()); + // 1.2.2 OpenMP Language Terminology + // Structured block - An executable statement with a single entry at the + // top and a single exit at the bottom. + // The point of exit cannot be a branch out of the structured block. + // longjmp() and throw() must not violate the entry/exit criteria. + CS->getCapturedDecl()->setNothrow(); + } + if (!hasClauses(Clauses, OMPC_to, OMPC_from)) { Diag(StartLoc, diag::err_omp_at_least_one_motion_clause_required); return StmtError(); } - return OMPTargetUpdateDirective::Create(Context, StartLoc, EndLoc, Clauses); + return OMPTargetUpdateDirective::Create(Context, StartLoc, EndLoc, Clauses, + AStmt); } StmtResult Sema::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, @@ -6861,7 +6949,6 @@ StmtResult Sema::ActOnOpenMPTargetSimdDirective( CS->getCapturedDecl()->setNothrow(); } - OMPLoopDirective::HelperExprs B; // In presence of clause 'collapse' with number of loops, it will define the // nested loops number. |