summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/CallGraph.cpp81
-rw-r--r--mlir/lib/Transforms/Inliner.cpp20
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp2
3 files changed, 44 insertions, 59 deletions
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index e88b1201443..cc82776c349 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -74,67 +74,38 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
static void computeCallGraph(Operation *op, CallGraph &cg,
- CallGraphNode *parentNode);
-
-/// Compute the set of callgraph nodes that are created by regions nested within
-/// 'op'.
-static void computeCallables(Operation *op, CallGraph &cg,
- CallGraphNode *parentNode) {
- if (op->getNumRegions() == 0)
+ CallGraphNode *parentNode, bool resolveCalls) {
+ if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
+ // If there is no parent node, we ignore this operation. Even if this
+ // operation was a call, there would be no callgraph node to attribute it
+ // to.
+ if (!resolveCalls || !parentNode)
+ return;
+ parentNode->addCallEdge(
+ cg.resolveCallable(call.getCallableForCallee(), op));
return;
- if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
- SmallVector<Region *, 1> callables;
- callableOp.getCallableRegions(callables);
- for (auto *callableRegion : callables)
- cg.getOrAddNode(callableRegion, parentNode);
}
-}
-/// Recursively compute the callgraph edges within the given region. Computed
-/// edges are placed into the given callgraph object.
-static void computeCallGraph(Region &region, CallGraph &cg,
- CallGraphNode *parentNode) {
- // Iterate over the nested operations twice:
- /// One to fully create nodes in the for each callable region of a nested
- /// operation;
- for (auto &block : region)
- for (auto &nested : block)
- computeCallables(&nested, cg, parentNode);
-
- /// And another to recursively compute the callgraph.
- for (auto &block : region)
- for (auto &nested : block)
- computeCallGraph(&nested, cg, parentNode);
-}
-
-/// Recursively compute the callgraph edges for the given operation. Computed
-/// edges are placed into the given callgraph object.
-static void computeCallGraph(Operation *op, CallGraph &cg,
- CallGraphNode *parentNode) {
// Compute the callgraph nodes and edges for each of the nested operations.
- auto isCallable = isa<CallableOpInterface>(op);
- for (auto &region : op->getRegions()) {
- // Check to see if this region is a callable node, if so this is the parent
- // node of the nested region.
- CallGraphNode *nestedParentNode;
- if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
- nestedParentNode = parentNode;
- computeCallGraph(region, cg, nestedParentNode);
+ if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
+ if (auto *callableRegion = callable.getCallableRegion())
+ parentNode = cg.getOrAddNode(callableRegion, parentNode);
+ else
+ return;
}
- // If there is no parent node, we ignore this operation. Even if this
- // operation was a call, there would be no callgraph node to attribute it to.
- if (!parentNode)
- return;
-
- // If this is a call operation, resolve the callee.
- if (auto call = dyn_cast<CallOpInterface>(op))
- parentNode->addCallEdge(
- cg.resolveCallable(call.getCallableForCallee(), op));
+ for (Region &region : op->getRegions())
+ for (Block &block : region)
+ for (Operation &nested : block)
+ computeCallGraph(&nested, cg, parentNode, resolveCalls);
}
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
- computeCallGraph(op, *this, /*parentNode=*/nullptr);
+ // Make two passes over the graph, one to compute the callables and one to
+ // resolve the calls. We split these up as we may have nested callable objects
+ // that need to be reserved before the calls.
+ computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
+ computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
}
/// Get or add a call graph node for the given region.
@@ -175,9 +146,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
// Get the callee operation from the callable.
Operation *callee;
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
- // TODO(riverriddle) Support nested references.
- callee = SymbolTable::lookupNearestSymbolFrom(from,
- symbolRef.getRootReference());
+ callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
else
callee = callable.get<Value>().getDefiningOp();
@@ -185,7 +154,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
// called region from it.
if (callee && callee->getNumRegions()) {
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
- if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
+ if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
}
}
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index b2cee7da083..d310316994f 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -86,8 +86,15 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
while (!worklist.empty()) {
for (Operation &op : *worklist.pop_back_val()) {
if (auto call = dyn_cast<CallOpInterface>(op)) {
- CallGraphNode *node =
- cg.resolveCallable(call.getCallableForCallee(), &op);
+ CallInterfaceCallable callable = call.getCallableForCallee();
+
+ // TODO(riverriddle) Support inlining nested call references.
+ if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
+ if (!symRef.isa<FlatSymbolRefAttr>())
+ continue;
+ }
+
+ CallGraphNode *node = cg.resolveCallable(callable, &op);
if (!node->isExternal())
calls.emplace_back(call, node);
continue;
@@ -274,6 +281,15 @@ struct InlinerPass : public OperationPass<InlinerPass> {
CallGraph &cg = getAnalysis<CallGraph>();
auto *context = &getContext();
+ // The inliner should only be run on operations that define a symbol table,
+ // as the callgraph will need to resolve references.
+ Operation *op = getOperation();
+ if (!op->hasTrait<OpTrait::SymbolTable>()) {
+ op->emitOpError() << " was scheduled to run under the inliner, but does "
+ "not define a symbol table";
+ return signalPassFailure();
+ }
+
// Collect a set of canonicalization patterns to use when simplifying
// callable regions within an SCC.
OwningRewritePatternList canonPatterns;
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 64591209dce..e91bdb71e9e 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -284,7 +284,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
if (src->empty())
return failure();
auto *entryBlock = &src->front();
- ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
+ ArrayRef<Type> callableResultTypes = callable.getCallableResults();
// Make sure that the number of arguments and results matchup between the call
// and the region.
OpenPOWER on IntegriCloud