diff options
Diffstat (limited to 'llvm')
-rw-r--r-- | llvm/lib/Transforms/IPO/FunctionAttrs.cpp | 62 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 11 | ||||
-rw-r--r-- | llvm/test/Transforms/FunctionAttrs/convergent.ll | 48 | ||||
-rw-r--r-- | llvm/test/Transforms/InstCombine/convergent.ll | 33 |
4 files changed, 56 insertions, 98 deletions
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index c321afb8aae..2eec4381a90 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -903,37 +903,49 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes, return MadeChange; } -/// Remove the convergent attribute from all functions in the SCC if every -/// callsite within the SCC is not convergent (except for calls to functions -/// within the SCC). Returns true if changes were made. +/// Removes convergent attributes where we can prove that none of the SCC's +/// callees are themselves convergent. Returns true if successful at removing +/// the attribute. static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) { - // No point checking if none of SCCNodes is convergent. - if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); })) - return false; + // Determines whether a function can be made non-convergent, ignoring all + // other functions in SCC. (A function can *actually* be made non-convergent + // only if all functions in its SCC can be made convergent.) + auto CanRemoveConvergent = [&](Function *F) { + if (!F->isConvergent()) + return true; - // Can't remove convergent from function declarations. - if (llvm::any_of(SCCNodes, [](Function *F) { return F->isDeclaration(); })) - return false; + // Can't remove convergent from declarations. + if (F->isDeclaration()) + return false; - // Can't remove convergent if any of our functions has a convergent call to a - // function not in the SCC. - for (Function *F : SCCNodes) - for (Instruction &I : instructions(*F)) { - CallSite CS(&I); - // Bail if is CS a convergent call to a function not in the SCC. - if (CS && CS.isConvergent() && - SCCNodes.count(CS.getCalledFunction()) == 0) - return false; - } + for (Instruction &I : instructions(*F)) + if (auto CS = CallSite(&I)) { + // Can't remove convergent if any of F's callees -- ignoring functions + // in the SCC itself -- are convergent. This needs to consider both + // function calls and intrinsic calls. We also assume indirect calls + // might call a convergent function. + // FIXME: We should revisit this when we put convergent onto calls + // instead of functions so that indirect calls which should be + // convergent are required to be marked as such. + Function *Callee = CS.getCalledFunction(); + if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent())) + return false; + } + + return true; + }; + + // We can remove the convergent attr from functions in the SCC if they all + // can be made non-convergent (because they call only non-convergent + // functions, other than each other). + if (!llvm::all_of(SCCNodes, CanRemoveConvergent)) + return false; - // If we got here, all of the calls the SCC makes to functions not in the SCC - // are non-convergent. Therefore all of the SCC's functions can also be made - // non-convergent. We'll remove the attr from the callsites in - // InstCombineCalls. + // If we got here, all of the SCC's callees are non-convergent. Therefore all + // of the SCC's functions can be marked as non-convergent. for (Function *F : SCCNodes) { if (F->isConvergent()) - DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName() - << "\n"); + DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n"); F->setNotConvergent(); } return true; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 71199a41ed3..249553aed27 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -2070,15 +2070,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { if (!isa<Function>(Callee) && transformConstExprCastCall(CS)) return nullptr; - if (Function *CalleeF = dyn_cast<Function>(Callee)) { - // Remove the convergent attr on calls when the callee is not convergent. - if (CS.isConvergent() && !CalleeF->isConvergent()) { - DEBUG(dbgs() << "Removing convergent attr from instr " - << CS.getInstruction() << "\n"); - CS.setNotConvergent(); - return CS.getInstruction(); - } - + if (Function *CalleeF = dyn_cast<Function>(Callee)) // If the call and callee calling conventions don't match, this call must // be unreachable, as the call is undefined. if (CalleeF->getCallingConv() != CS.getCallingConv() && @@ -2103,7 +2095,6 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { Constant::getNullValue(CalleeF->getType())); return nullptr; } - } if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { // If CS does not return void then replaceAllUsesWith undef. diff --git a/llvm/test/Transforms/FunctionAttrs/convergent.ll b/llvm/test/Transforms/FunctionAttrs/convergent.ll index bc21d85ec22..46370d7bf30 100644 --- a/llvm/test/Transforms/FunctionAttrs/convergent.ll +++ b/llvm/test/Transforms/FunctionAttrs/convergent.ll @@ -1,4 +1,4 @@ -; RUN: opt -functionattrs -S < %s | FileCheck %s +; RUN: opt < %s -basicaa -functionattrs -rpo-functionattrs -S | FileCheck %s ; CHECK: Function Attrs ; CHECK-NOT: convergent @@ -24,37 +24,16 @@ declare i32 @k() convergent ; CHECK-SAME: convergent ; CHECK-NEXT: define i32 @extern() define i32 @extern() convergent { - %a = call i32 @k() convergent - ret i32 %a -} - -; Convergent should not be removed on the function here. Although the call is -; not explicitly convergent, it picks up the convergent attr from the callee. -; -; CHECK: Function Attrs -; CHECK-SAME: convergent -; CHECK-NEXT: define i32 @extern_non_convergent_call() -define i32 @extern_non_convergent_call() convergent { %a = call i32 @k() ret i32 %a } ; CHECK: Function Attrs ; CHECK-SAME: convergent -; CHECK-NEXT: define i32 @indirect_convergent_call( -define i32 @indirect_convergent_call(i32 ()* %f) convergent { - %a = call i32 %f() convergent - ret i32 %a -} -; Give indirect_non_convergent_call the norecurse attribute so we get a -; "Function Attrs" comment in the output. -; -; CHECK: Function Attrs -; CHECK-NOT: convergent -; CHECK-NEXT: define i32 @indirect_non_convergent_call( -define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse { - %a = call i32 %f() - ret i32 %a +; CHECK-NEXT: define i32 @call_extern() +define i32 @call_extern() convergent { + %a = call i32 @extern() + ret i32 %a } ; CHECK: Function Attrs @@ -66,16 +45,25 @@ declare void @llvm.cuda.syncthreads() convergent ; CHECK-SAME: convergent ; CHECK-NEXT: define i32 @intrinsic() define i32 @intrinsic() convergent { - ; Implicitly convergent, because the intrinsic is convergent. call void @llvm.cuda.syncthreads() ret i32 0 } +@xyz = global i32 ()* null +; CHECK: Function Attrs +; CHECK-SAME: convergent +; CHECK-NEXT: define i32 @functionptr() +define i32 @functionptr() convergent { + %1 = load i32 ()*, i32 ()** @xyz + %2 = call i32 %1() + ret i32 %2 +} + ; CHECK: Function Attrs ; CHECK-NOT: convergent ; CHECK-NEXT: define i32 @recursive1() define i32 @recursive1() convergent { - %a = call i32 @recursive2() convergent + %a = call i32 @recursive2() ret i32 %a } @@ -83,7 +71,7 @@ define i32 @recursive1() convergent { ; CHECK-NOT: convergent ; CHECK-NEXT: define i32 @recursive2() define i32 @recursive2() convergent { - %a = call i32 @recursive1() convergent + %a = call i32 @recursive1() ret i32 %a } @@ -91,7 +79,7 @@ define i32 @recursive2() convergent { ; CHECK-SAME: convergent ; CHECK-NEXT: define i32 @noopt() define i32 @noopt() convergent optnone noinline { - %a = call i32 @noopt_friend() convergent + %a = call i32 @noopt_friend() ret i32 0 } diff --git a/llvm/test/Transforms/InstCombine/convergent.ll b/llvm/test/Transforms/InstCombine/convergent.ll deleted file mode 100644 index 4ed40d81bad..00000000000 --- a/llvm/test/Transforms/InstCombine/convergent.ll +++ /dev/null @@ -1,33 +0,0 @@ -; RUN: opt -instcombine -S < %s | FileCheck %s - -declare i32 @k() convergent -declare i32 @f() - -define i32 @extern() { - ; Convergent attr shouldn't be removed here; k is convergent. - ; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]] - %a = call i32 @k() convergent - ret i32 %a -} - -define i32 @extern_no_attr() { - ; Convergent attr shouldn't be added here, even though k is convergent. - ; CHECK: call i32 @k(){{$}} - %a = call i32 @k() - ret i32 %a -} - -define i32 @no_extern() { - ; Convergent should be removed here, as the target is convergent. - ; CHECK: call i32 @f(){{$}} - %a = call i32 @f() convergent - ret i32 %a -} - -define i32 @indirect_call(i32 ()* %f) { - ; CHECK call i32 %f() [[CONVERGENT_ATTR]] - %a = call i32 %f() convergent - ret i32 %a -} - -; CHECK: [[CONVERGENT_ATTR]] = { convergent } |