diff options
Diffstat (limited to 'clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp')
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp | 52 |
1 files changed, 43 insertions, 9 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp index 5b7f0c3e43c..8cf5bb2f44b 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -187,7 +187,7 @@ class CheckVarsEscapingDeclContext final RecordDecl *GlobalizedRD = nullptr; llvm::SmallDenseMap<const ValueDecl *, const FieldDecl *> MappedDeclsFields; bool AllEscaped = false; - bool IsForParallelRegion = false; + bool IsForCombinedParallelRegion = false; static llvm::Optional<OMPDeclareTargetDeclAttr::MapTypeTy> isDeclareTargetDeclaration(const ValueDecl *VD) { @@ -210,7 +210,7 @@ class CheckVarsEscapingDeclContext final if (const FieldDecl *FD = CSI->lookup(cast<VarDecl>(VD))) { // Check if need to capture the variable that was already captured by // value in the outer region. - if (!IsForParallelRegion) { + if (!IsForCombinedParallelRegion) { if (!FD->hasAttrs()) return; const auto *Attr = FD->getAttr<OMPCaptureKindAttr>(); @@ -225,13 +225,13 @@ class CheckVarsEscapingDeclContext final assert(!VD->getType()->isVariablyModifiedType() && "Parameter captured by value with variably modified type"); EscapedParameters.insert(VD); - } else if (!IsForParallelRegion) { + } else if (!IsForCombinedParallelRegion) { return; } } } if ((!CGF.CapturedStmtInfo || - (IsForParallelRegion && CGF.CapturedStmtInfo)) && + (IsForCombinedParallelRegion && CGF.CapturedStmtInfo)) && VD->getType()->isReferenceType()) // Do not globalize variables with reference type. return; @@ -253,18 +253,49 @@ class CheckVarsEscapingDeclContext final } } } - void VisitOpenMPCapturedStmt(const CapturedStmt *S, bool IsParallelRegion) { + void VisitOpenMPCapturedStmt(const CapturedStmt *S, + ArrayRef<OMPClause *> Clauses, + bool IsCombinedParallelRegion) { if (!S) return; for (const CapturedStmt::Capture &C : S->captures()) { if (C.capturesVariable() && !C.capturesVariableByCopy()) { const ValueDecl *VD = C.getCapturedVar(); - bool SavedIsParallelRegion = IsForParallelRegion; - IsForParallelRegion = IsParallelRegion; + bool SavedIsForCombinedParallelRegion = IsForCombinedParallelRegion; + if (IsCombinedParallelRegion) { + // Check if the variable is privatized in the combined construct and + // those private copies must be shared in the inner parallel + // directive. + IsForCombinedParallelRegion = false; + for (const OMPClause *C : Clauses) { + if (!isOpenMPPrivate(C->getClauseKind()) || + C->getClauseKind() == OMPC_reduction || + C->getClauseKind() == OMPC_linear || + C->getClauseKind() == OMPC_private) + continue; + ArrayRef<const Expr *> Vars; + if (const auto *PC = dyn_cast<OMPFirstprivateClause>(C)) + Vars = PC->getVarRefs(); + else if (const auto *PC = dyn_cast<OMPLastprivateClause>(C)) + Vars = PC->getVarRefs(); + else + llvm_unreachable("Unexpected clause."); + for (const auto *E : Vars) { + const Decl *D = + cast<DeclRefExpr>(E)->getDecl()->getCanonicalDecl(); + if (D == VD->getCanonicalDecl()) { + IsForCombinedParallelRegion = true; + break; + } + } + if (IsForCombinedParallelRegion) + break; + } + } markAsEscaped(VD); if (isa<OMPCapturedExprDecl>(VD)) VisitValueDecl(VD); - IsForParallelRegion = SavedIsParallelRegion; + IsForCombinedParallelRegion = SavedIsForCombinedParallelRegion; } } } @@ -341,7 +372,10 @@ public: VisitStmt(S->getCapturedStmt()); return; } - VisitOpenMPCapturedStmt(S, CaptureRegions.back() == OMPD_parallel); + VisitOpenMPCapturedStmt( + S, D->clauses(), + CaptureRegions.back() == OMPD_parallel && + isOpenMPDistributeDirective(D->getDirectiveKind())); } } void VisitCapturedStmt(const CapturedStmt *S) { |