summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/IPO/FunctionAttrs.cpp')
-rw-r--r--llvm/lib/Transforms/IPO/FunctionAttrs.cpp62
1 files changed, 37 insertions, 25 deletions
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index c321afb8aae..2eec4381a90 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -903,37 +903,49 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes,
return MadeChange;
}
-/// Remove the convergent attribute from all functions in the SCC if every
-/// callsite within the SCC is not convergent (except for calls to functions
-/// within the SCC). Returns true if changes were made.
+/// Removes convergent attributes where we can prove that none of the SCC's
+/// callees are themselves convergent. Returns true if successful at removing
+/// the attribute.
static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
- // No point checking if none of SCCNodes is convergent.
- if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); }))
- return false;
+ // Determines whether a function can be made non-convergent, ignoring all
+ // other functions in SCC. (A function can *actually* be made non-convergent
+ // only if all functions in its SCC can be made convergent.)
+ auto CanRemoveConvergent = [&](Function *F) {
+ if (!F->isConvergent())
+ return true;
- // Can't remove convergent from function declarations.
- if (llvm::any_of(SCCNodes, [](Function *F) { return F->isDeclaration(); }))
- return false;
+ // Can't remove convergent from declarations.
+ if (F->isDeclaration())
+ return false;
- // Can't remove convergent if any of our functions has a convergent call to a
- // function not in the SCC.
- for (Function *F : SCCNodes)
- for (Instruction &I : instructions(*F)) {
- CallSite CS(&I);
- // Bail if is CS a convergent call to a function not in the SCC.
- if (CS && CS.isConvergent() &&
- SCCNodes.count(CS.getCalledFunction()) == 0)
- return false;
- }
+ for (Instruction &I : instructions(*F))
+ if (auto CS = CallSite(&I)) {
+ // Can't remove convergent if any of F's callees -- ignoring functions
+ // in the SCC itself -- are convergent. This needs to consider both
+ // function calls and intrinsic calls. We also assume indirect calls
+ // might call a convergent function.
+ // FIXME: We should revisit this when we put convergent onto calls
+ // instead of functions so that indirect calls which should be
+ // convergent are required to be marked as such.
+ Function *Callee = CS.getCalledFunction();
+ if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
+ return false;
+ }
+
+ return true;
+ };
+
+ // We can remove the convergent attr from functions in the SCC if they all
+ // can be made non-convergent (because they call only non-convergent
+ // functions, other than each other).
+ if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
+ return false;
- // If we got here, all of the calls the SCC makes to functions not in the SCC
- // are non-convergent. Therefore all of the SCC's functions can also be made
- // non-convergent. We'll remove the attr from the callsites in
- // InstCombineCalls.
+ // If we got here, all of the SCC's callees are non-convergent. Therefore all
+ // of the SCC's functions can be marked as non-convergent.
for (Function *F : SCCNodes) {
if (F->isConvergent())
- DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
- << "\n");
+ DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
F->setNotConvergent();
}
return true;
OpenPOWER on IntegriCloud