summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp26
-rw-r--r--clang/test/OpenMP/nvptx_parallel_codegen.cpp3
2 files changed, 26 insertions, 3 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
index 284d14b9958..6c908412e0c 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -107,6 +107,10 @@ enum OpenMPRTLFunctionNVPTX {
/// Call to void __kmpc_barrier_simple_spmd(ident_t *loc, kmp_int32
/// global_tid);
OMPRTL__kmpc_barrier_simple_spmd,
+ /// Call to int32_t __kmpc_warp_active_thread_mask(void);
+ OMPRTL_NVPTX__kmpc_warp_active_thread_mask,
+ /// Call to void __kmpc_syncwarp(int32_t Mask);
+ OMPRTL_NVPTX__kmpc_syncwarp,
};
/// Pre(post)-action for different OpenMP constructs specialized for NVPTX.
@@ -1794,6 +1798,20 @@ CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
->addFnAttr(llvm::Attribute::Convergent);
break;
}
+ case OMPRTL_NVPTX__kmpc_warp_active_thread_mask: {
+ // Build int32_t __kmpc_warp_active_thread_mask(void);
+ auto *FnTy =
+ llvm::FunctionType::get(CGM.Int32Ty, llvm::None, /*isVarArg=*/false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_warp_active_thread_mask");
+ break;
+ }
+ case OMPRTL_NVPTX__kmpc_syncwarp: {
+ // Build void __kmpc_syncwarp(kmp_int32 Mask);
+ auto *FnTy =
+ llvm::FunctionType::get(CGM.VoidTy, CGM.Int32Ty, /*isVarArg=*/false);
+ RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_syncwarp");
+ break;
+ }
}
return RTLFn;
}
@@ -2700,6 +2718,9 @@ void CGOpenMPRuntimeNVPTX::emitCriticalRegion(
llvm::BasicBlock *BodyBB = CGF.createBasicBlock("omp.critical.body");
llvm::BasicBlock *ExitBB = CGF.createBasicBlock("omp.critical.exit");
+ // Get the mask of active threads in the warp.
+ llvm::Value *Mask = CGF.EmitRuntimeCall(
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_warp_active_thread_mask));
// Fetch team-local id of the thread.
llvm::Value *ThreadID = getNVPTXThreadID(CGF);
@@ -2740,8 +2761,9 @@ void CGOpenMPRuntimeNVPTX::emitCriticalRegion(
// Block waits for all threads in current team to finish then increments the
// counter variable and returns to the loop.
CGF.EmitBlock(SyncBB);
- emitBarrierCall(CGF, Loc, OMPD_unknown, /*EmitChecks=*/false,
- /*ForceSimpleCall=*/true);
+ // Reconverge active threads in the warp.
+ (void)CGF.EmitRuntimeCall(
+ createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_syncwarp), Mask);
llvm::Value *IncCounterVal =
CGF.Builder.CreateNSWAdd(CounterVal, CGF.Builder.getInt32(1));
diff --git a/clang/test/OpenMP/nvptx_parallel_codegen.cpp b/clang/test/OpenMP/nvptx_parallel_codegen.cpp
index 155ffc5100e..32061bf7386 100644
--- a/clang/test/OpenMP/nvptx_parallel_codegen.cpp
+++ b/clang/test/OpenMP/nvptx_parallel_codegen.cpp
@@ -343,6 +343,7 @@ int bar(int n){
// CHECK-LABEL: define internal void @{{.+}}(i32* noalias %{{.+}}, i32* noalias %{{.+}}, i32* dereferenceable{{.*}})
// CHECK: [[CC:%.+]] = alloca i32,
+// CHECK: [[MASK:%.+]] = call i32 @__kmpc_warp_active_thread_mask()
// CHECK: [[TID:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
// CHECK: [[NUM_THREADS:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
// CHECK: store i32 0, i32* [[CC]],
@@ -362,7 +363,7 @@ int bar(int n){
// CHECK: store i32
// CHECK: call void @__kmpc_end_critical(
-// CHECK: call void @__kmpc_barrier(%struct.ident_t* @{{.+}}, i32 %{{.+}})
+// CHECK: call void @__kmpc_syncwarp(i32 [[MASK]])
// CHECK: [[NEW_CC_VAL:%.+]] = add nsw i32 [[CC_VAL]], 1
// CHECK: store i32 [[NEW_CC_VAL]], i32* [[CC]],
// CHECK: br label
OpenPOWER on IntegriCloud