diff options
author | Alexey Bataev <a.bataev@hotmail.com> | 2018-05-07 17:23:05 +0000 |
---|---|---|
committer | Alexey Bataev <a.bataev@hotmail.com> | 2018-05-07 17:23:05 +0000 |
commit | 504fc2d0cded755406e25939f355fe4701849df8 (patch) | |
tree | 0c6381307673584859aebcebf5b23cacac861ac9 /clang/lib/CodeGen | |
parent | c90bb6d762da723ad41701fae82020b500332fc3 (diff) | |
download | bcm5719-llvm-504fc2d0cded755406e25939f355fe4701849df8.tar.gz bcm5719-llvm-504fc2d0cded755406e25939f355fe4701849df8.zip |
[OPENMP, NVPTX] Codegen for critical construct.
Added correct codegen for the critical construct on NVPTX devices.
llvm-svn: 331652
Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp | 60 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h | 10 |
2 files changed, 70 insertions, 0 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp index 9e6f2b4b9a3..19b3147d26b 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp @@ -1837,6 +1837,66 @@ void CGOpenMPRuntimeNVPTX::emitSpmdParallelCall( emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, OutlinedFnArgs); } +void CGOpenMPRuntimeNVPTX::emitCriticalRegion( + CodeGenFunction &CGF, StringRef CriticalName, + const RegionCodeGenTy &CriticalOpGen, SourceLocation Loc, + const Expr *Hint) { + llvm::BasicBlock *LoopBB = CGF.createBasicBlock("omp.critical.loop"); + llvm::BasicBlock *TestBB = CGF.createBasicBlock("omp.critical.test"); + llvm::BasicBlock *SyncBB = CGF.createBasicBlock("omp.critical.sync"); + llvm::BasicBlock *BodyBB = CGF.createBasicBlock("omp.critical.body"); + llvm::BasicBlock *ExitBB = CGF.createBasicBlock("omp.critical.exit"); + + // Fetch team-local id of the thread. + llvm::Value *ThreadID = getNVPTXThreadID(CGF); + + // Get the width of the team. + llvm::Value *TeamWidth = getNVPTXNumThreads(CGF); + + // Initialize the counter variable for the loop. + QualType Int32Ty = + CGF.getContext().getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/0); + Address Counter = CGF.CreateMemTemp(Int32Ty, "critical_counter"); + LValue CounterLVal = CGF.MakeAddrLValue(Counter, Int32Ty); + CGF.EmitStoreOfScalar(llvm::Constant::getNullValue(CGM.Int32Ty), CounterLVal, + /*isInit=*/true); + + // Block checks if loop counter exceeds upper bound. + CGF.EmitBlock(LoopBB); + llvm::Value *CounterVal = CGF.EmitLoadOfScalar(CounterLVal, Loc); + llvm::Value *CmpLoopBound = CGF.Builder.CreateICmpSLT(CounterVal, TeamWidth); + CGF.Builder.CreateCondBr(CmpLoopBound, TestBB, ExitBB); + + // Block tests which single thread should execute region, and which threads + // should go straight to synchronisation point. + CGF.EmitBlock(TestBB); + CounterVal = CGF.EmitLoadOfScalar(CounterLVal, Loc); + llvm::Value *CmpThreadToCounter = + CGF.Builder.CreateICmpEQ(ThreadID, CounterVal); + CGF.Builder.CreateCondBr(CmpThreadToCounter, BodyBB, SyncBB); + + // Block emits the body of the critical region. + CGF.EmitBlock(BodyBB); + + // Output the critical statement. + CriticalOpGen(CGF); + + // After the body surrounded by the critical region, the single executing + // thread will jump to the synchronisation point. + // Block waits for all threads in current team to finish then increments the + // counter variable and returns to the loop. + CGF.EmitBlock(SyncBB); + getNVPTXCTABarrier(CGF); + + llvm::Value *IncCounterVal = + CGF.Builder.CreateNSWAdd(CounterVal, CGF.Builder.getInt32(1)); + CGF.EmitStoreOfScalar(IncCounterVal, CounterLVal); + CGF.EmitBranch(LoopBB); + + // Block that is reached when all threads in the team complete the region. + CGF.EmitBlock(ExitBB, /*IsFinished=*/true); +} + /// Cast value to the specified type. static llvm::Value *castValueToType(CodeGenFunction &CGF, llvm::Value *Val, QualType ValTy, QualType CastTy, diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h index a5f39c28a7a..ac8011dc790 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h +++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h @@ -250,6 +250,16 @@ public: ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) override; + /// Emits a critical region. + /// \param CriticalName Name of the critical region. + /// \param CriticalOpGen Generator for the statement associated with the given + /// critical region. + /// \param Hint Value of the 'hint' clause (optional). + void emitCriticalRegion(CodeGenFunction &CGF, StringRef CriticalName, + const RegionCodeGenTy &CriticalOpGen, + SourceLocation Loc, + const Expr *Hint = nullptr) override; + /// Emit a code for reduction clause. /// /// \param Privates List of private copies for original reduction arguments. |