diff options
Diffstat (limited to 'clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp')
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp | 98 |
1 files changed, 55 insertions, 43 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp index 9452bdea4c7..e3cec13f7d5 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -184,9 +184,10 @@ class CheckVarsEscapingDeclContext final llvm::SetVector<const ValueDecl *> EscapedDecls; llvm::SetVector<const ValueDecl *> EscapedVariableLengthDecls; llvm::SmallPtrSet<const Decl *, 4> EscapedParameters; - bool AllEscaped = false; RecordDecl *GlobalizedRD = nullptr; llvm::SmallDenseMap<const ValueDecl *, const FieldDecl *> MappedDeclsFields; + bool AllEscaped = false; + bool IsForParallelRegion = false; static llvm::Optional<OMPDeclareTargetDeclAttr::MapTypeTy> isDeclareTargetDeclaration(const ValueDecl *VD) { @@ -207,23 +208,32 @@ class CheckVarsEscapingDeclContext final // Variables captured by value must be globalized. if (auto *CSI = CGF.CapturedStmtInfo) { if (const FieldDecl *FD = CSI->lookup(cast<VarDecl>(VD))) { - if (!FD->hasAttrs()) - return; - const auto *Attr = FD->getAttr<OMPCaptureKindAttr>(); - if (!Attr) - return; - if (!isOpenMPPrivate( - static_cast<OpenMPClauseKind>(Attr->getCaptureKind())) || - Attr->getCaptureKind() == OMPC_map) - return; - if (FD->getType()->isReferenceType()) + // Check if need to capture the variable that was already captured by + // value in the outer region. + if (!IsForParallelRegion) { + if (!FD->hasAttrs()) + return; + const auto *Attr = FD->getAttr<OMPCaptureKindAttr>(); + if (!Attr) + return; + if (!isOpenMPPrivate( + static_cast<OpenMPClauseKind>(Attr->getCaptureKind())) || + Attr->getCaptureKind() == OMPC_map) + return; + } + if (!FD->getType()->isReferenceType()) { + assert(!VD->getType()->isVariablyModifiedType() && + "Parameter captured by value with variably modified type"); + EscapedParameters.insert(VD); + } else if (!IsForParallelRegion) { return; - assert(!VD->getType()->isVariablyModifiedType() && - "Parameter captured by value with variably modified type"); - EscapedParameters.insert(VD); + } } - } else if (VD->getType()->isReferenceType()) - // Do not globalize variables with reference or pointer type. + } + if ((!CGF.CapturedStmtInfo || + (IsForParallelRegion && CGF.CapturedStmtInfo)) && + VD->getType()->isReferenceType()) + // Do not globalize variables with reference type. return; if (VD->getType()->isVariablyModifiedType()) EscapedVariableLengthDecls.insert(VD); @@ -243,15 +253,18 @@ class CheckVarsEscapingDeclContext final } } } - void VisitOpenMPCapturedStmt(const CapturedStmt *S) { + void VisitOpenMPCapturedStmt(const CapturedStmt *S, bool IsParallelRegion) { if (!S) return; for (const CapturedStmt::Capture &C : S->captures()) { if (C.capturesVariable() && !C.capturesVariableByCopy()) { const ValueDecl *VD = C.getCapturedVar(); + bool SavedIsParallelRegion = IsForParallelRegion; + IsForParallelRegion = IsParallelRegion; markAsEscaped(VD); if (isa<OMPCapturedExprDecl>(VD)) VisitValueDecl(VD); + IsForParallelRegion = SavedIsParallelRegion; } } } @@ -316,20 +329,19 @@ public: void VisitOMPExecutableDirective(const OMPExecutableDirective *D) { if (!D) return; - if (D->hasAssociatedStmt()) { - if (const auto *S = - dyn_cast_or_null<CapturedStmt>(D->getAssociatedStmt())) { - // Do not analyze directives that do not actually require capturing, - // like `omp for` or `omp simd` directives. - llvm::SmallVector<OpenMPDirectiveKind, 4> CaptureRegions; - getOpenMPCaptureRegions(CaptureRegions, D->getDirectiveKind()); - if (CaptureRegions.size() == 1 && - CaptureRegions.back() == OMPD_unknown) { - VisitStmt(S->getCapturedStmt()); - return; - } - VisitOpenMPCapturedStmt(S); + if (!D->hasAssociatedStmt()) + return; + if (const auto *S = + dyn_cast_or_null<CapturedStmt>(D->getAssociatedStmt())) { + // Do not analyze directives that do not actually require capturing, + // like `omp for` or `omp simd` directives. + llvm::SmallVector<OpenMPDirectiveKind, 4> CaptureRegions; + getOpenMPCaptureRegions(CaptureRegions, D->getDirectiveKind()); + if (CaptureRegions.size() == 1 && CaptureRegions.back() == OMPD_unknown) { + VisitStmt(S->getCapturedStmt()); + return; } + VisitOpenMPCapturedStmt(S, CaptureRegions.back() == OMPD_parallel); } } void VisitCapturedStmt(const CapturedStmt *S) { @@ -551,9 +563,9 @@ static void syncParallelThreads(CodeGenFunction &CGF, llvm::Value *NumThreads) { /// CTA. The threads in the last warp are reserved for master execution. /// For the 'spmd' execution mode, all threads in a CTA are part of the team. static llvm::Value *getThreadLimit(CodeGenFunction &CGF, - bool IsInSpmdExecutionMode = false) { + bool IsInSPMDExecutionMode = false) { CGBuilderTy &Bld = CGF.Builder; - return IsInSpmdExecutionMode + return IsInSPMDExecutionMode ? getNVPTXNumThreads(CGF) : Bld.CreateNUWSub(getNVPTXNumThreads(CGF), getNVPTXWarpSize(CGF), "thread_limit"); @@ -930,7 +942,7 @@ void CGOpenMPRuntimeNVPTX::emitNonSPMDEntryFooter(CodeGenFunction &CGF, EST.ExitBB = nullptr; } -void CGOpenMPRuntimeNVPTX::emitSpmdKernel(const OMPExecutableDirective &D, +void CGOpenMPRuntimeNVPTX::emitSPMDKernel(const OMPExecutableDirective &D, StringRef ParentName, llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID, @@ -951,10 +963,10 @@ void CGOpenMPRuntimeNVPTX::emitSpmdKernel(const OMPExecutableDirective &D, const OMPExecutableDirective &D) : RT(RT), EST(EST), D(D) {} void Enter(CodeGenFunction &CGF) override { - RT.emitSpmdEntryHeader(CGF, EST, D); + RT.emitSPMDEntryHeader(CGF, EST, D); } void Exit(CodeGenFunction &CGF) override { - RT.emitSpmdEntryFooter(CGF, EST); + RT.emitSPMDEntryFooter(CGF, EST); } } Action(*this, EST, D); CodeGen.setAction(Action); @@ -962,7 +974,7 @@ void CGOpenMPRuntimeNVPTX::emitSpmdKernel(const OMPExecutableDirective &D, IsOffloadEntry, CodeGen); } -void CGOpenMPRuntimeNVPTX::emitSpmdEntryHeader( +void CGOpenMPRuntimeNVPTX::emitSPMDEntryHeader( CodeGenFunction &CGF, EntryFunctionState &EST, const OMPExecutableDirective &D) { CGBuilderTy &Bld = CGF.Builder; @@ -974,7 +986,7 @@ void CGOpenMPRuntimeNVPTX::emitSpmdEntryHeader( // Initialize the OMP state in the runtime; called by all active threads. // TODO: Set RequiresOMPRuntime and RequiresDataSharing parameters // based on code analysis of the target region. - llvm::Value *Args[] = {getThreadLimit(CGF, /*IsInSpmdExecutionMode=*/true), + llvm::Value *Args[] = {getThreadLimit(CGF, /*IsInSPMDExecutionMode=*/true), /*RequiresOMPRuntime=*/Bld.getInt16(1), /*RequiresDataSharing=*/Bld.getInt16(1)}; CGF.EmitRuntimeCall( @@ -986,7 +998,7 @@ void CGOpenMPRuntimeNVPTX::emitSpmdEntryHeader( IsInTargetMasterThreadRegion = true; } -void CGOpenMPRuntimeNVPTX::emitSpmdEntryFooter(CodeGenFunction &CGF, +void CGOpenMPRuntimeNVPTX::emitSPMDEntryFooter(CodeGenFunction &CGF, EntryFunctionState &EST) { IsInTargetMasterThreadRegion = false; if (!CGF.HaveInsertPoint()) @@ -1465,7 +1477,7 @@ void CGOpenMPRuntimeNVPTX::emitTargetOutlinedFunction( bool Mode = supportsSPMDExecutionMode(CGM.getContext(), D); if (Mode) - emitSpmdKernel(D, ParentName, OutlinedFn, OutlinedFnID, IsOffloadEntry, + emitSPMDKernel(D, ParentName, OutlinedFn, OutlinedFnID, IsOffloadEntry, CodeGen); else emitNonSPMDKernel(D, ParentName, OutlinedFn, OutlinedFnID, IsOffloadEntry, @@ -1483,7 +1495,7 @@ CGOpenMPRuntimeNVPTX::CGOpenMPRuntimeNVPTX(CodeGenModule &CGM) void CGOpenMPRuntimeNVPTX::emitProcBindClause(CodeGenFunction &CGF, OpenMPProcBindClauseKind ProcBind, SourceLocation Loc) { - // Do nothing in case of Spmd mode and L0 parallel. + // Do nothing in case of SPMD mode and L0 parallel. if (getExecutionMode() == CGOpenMPRuntimeNVPTX::EM_SPMD) return; @@ -1493,7 +1505,7 @@ void CGOpenMPRuntimeNVPTX::emitProcBindClause(CodeGenFunction &CGF, void CGOpenMPRuntimeNVPTX::emitNumThreadsClause(CodeGenFunction &CGF, llvm::Value *NumThreads, SourceLocation Loc) { - // Do nothing in case of Spmd mode and L0 parallel. + // Do nothing in case of SPMD mode and L0 parallel. if (getExecutionMode() == CGOpenMPRuntimeNVPTX::EM_SPMD) return; @@ -1718,7 +1730,7 @@ void CGOpenMPRuntimeNVPTX::emitParallelCall( return; if (getExecutionMode() == CGOpenMPRuntimeNVPTX::EM_SPMD) - emitSpmdParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond); + emitSPMDParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond); else emitNonSPMDParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond); } @@ -1904,7 +1916,7 @@ void CGOpenMPRuntimeNVPTX::emitNonSPMDParallelCall( } } -void CGOpenMPRuntimeNVPTX::emitSpmdParallelCall( +void CGOpenMPRuntimeNVPTX::emitSPMDParallelCall( CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn, ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) { // Just call the outlined function to execute the parallel region. |