diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/CallGraph.cpp | 81 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Inliner.cpp | 20 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 2 |
3 files changed, 44 insertions, 59 deletions
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index e88b1201443..cc82776c349 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -74,67 +74,38 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { /// Recursively compute the callgraph edges for the given operation. Computed /// edges are placed into the given callgraph object. static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode); - -/// Compute the set of callgraph nodes that are created by regions nested within -/// 'op'. -static void computeCallables(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { - if (op->getNumRegions() == 0) + CallGraphNode *parentNode, bool resolveCalls) { + if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { + // If there is no parent node, we ignore this operation. Even if this + // operation was a call, there would be no callgraph node to attribute it + // to. + if (!resolveCalls || !parentNode) + return; + parentNode->addCallEdge( + cg.resolveCallable(call.getCallableForCallee(), op)); return; - if (auto callableOp = dyn_cast<CallableOpInterface>(op)) { - SmallVector<Region *, 1> callables; - callableOp.getCallableRegions(callables); - for (auto *callableRegion : callables) - cg.getOrAddNode(callableRegion, parentNode); } -} -/// Recursively compute the callgraph edges within the given region. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Region ®ion, CallGraph &cg, - CallGraphNode *parentNode) { - // Iterate over the nested operations twice: - /// One to fully create nodes in the for each callable region of a nested - /// operation; - for (auto &block : region) - for (auto &nested : block) - computeCallables(&nested, cg, parentNode); - - /// And another to recursively compute the callgraph. - for (auto &block : region) - for (auto &nested : block) - computeCallGraph(&nested, cg, parentNode); -} - -/// Recursively compute the callgraph edges for the given operation. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { // Compute the callgraph nodes and edges for each of the nested operations. - auto isCallable = isa<CallableOpInterface>(op); - for (auto ®ion : op->getRegions()) { - // Check to see if this region is a callable node, if so this is the parent - // node of the nested region. - CallGraphNode *nestedParentNode; - if (!isCallable || !(nestedParentNode = cg.lookupNode(®ion))) - nestedParentNode = parentNode; - computeCallGraph(region, cg, nestedParentNode); + if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { + if (auto *callableRegion = callable.getCallableRegion()) + parentNode = cg.getOrAddNode(callableRegion, parentNode); + else + return; } - // If there is no parent node, we ignore this operation. Even if this - // operation was a call, there would be no callgraph node to attribute it to. - if (!parentNode) - return; - - // If this is a call operation, resolve the callee. - if (auto call = dyn_cast<CallOpInterface>(op)) - parentNode->addCallEdge( - cg.resolveCallable(call.getCallableForCallee(), op)); + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &nested : block) + computeCallGraph(&nested, cg, parentNode, resolveCalls); } CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { - computeCallGraph(op, *this, /*parentNode=*/nullptr); + // Make two passes over the graph, one to compute the callables and one to + // resolve the calls. We split these up as we may have nested callable objects + // that need to be reserved before the calls. + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false); + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true); } /// Get or add a call graph node for the given region. @@ -175,9 +146,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // Get the callee operation from the callable. Operation *callee; if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) - // TODO(riverriddle) Support nested references. - callee = SymbolTable::lookupNearestSymbolFrom(from, - symbolRef.getRootReference()); + callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef); else callee = callable.get<Value>().getDefiningOp(); @@ -185,7 +154,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // called region from it. if (callee && callee->getNumRegions()) { if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { - if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) + if (auto *node = lookupNode(callableOp.getCallableRegion())) return node; } } diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index b2cee7da083..d310316994f 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -86,8 +86,15 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast<CallOpInterface>(op)) { - CallGraphNode *node = - cg.resolveCallable(call.getCallableForCallee(), &op); + CallInterfaceCallable callable = call.getCallableForCallee(); + + // TODO(riverriddle) Support inlining nested call references. + if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { + if (!symRef.isa<FlatSymbolRefAttr>()) + continue; + } + + CallGraphNode *node = cg.resolveCallable(callable, &op); if (!node->isExternal()) calls.emplace_back(call, node); continue; @@ -274,6 +281,15 @@ struct InlinerPass : public OperationPass<InlinerPass> { CallGraph &cg = getAnalysis<CallGraph>(); auto *context = &getContext(); + // The inliner should only be run on operations that define a symbol table, + // as the callgraph will need to resolve references. + Operation *op = getOperation(); + if (!op->hasTrait<OpTrait::SymbolTable>()) { + op->emitOpError() << " was scheduled to run under the inliner, but does " + "not define a symbol table"; + return signalPassFailure(); + } + // Collect a set of canonicalization patterns to use when simplifying // callable regions within an SCC. OwningRewritePatternList canonPatterns; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 64591209dce..e91bdb71e9e 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -284,7 +284,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, if (src->empty()) return failure(); auto *entryBlock = &src->front(); - ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); + ArrayRef<Type> callableResultTypes = callable.getCallableResults(); // Make sure that the number of arguments and results matchup between the call // and the region. |

