diff options
| author | River Riddle <riverriddle@google.com> | 2020-01-13 15:46:40 -0800 |
|---|---|---|
| committer | River Riddle <riverriddle@google.com> | 2020-01-13 15:51:28 -0800 |
| commit | c7748404920b3674e79059cbbe73b6041a214444 (patch) | |
| tree | e7dc6a063e4f67e6b7b79f6d2fc71b067a315fa2 /mlir/lib | |
| parent | 6fca03f0cae77c275870c4569bfeeb7ca0f561a6 (diff) | |
| download | bcm5719-llvm-c7748404920b3674e79059cbbe73b6041a214444.tar.gz bcm5719-llvm-c7748404920b3674e79059cbbe73b6041a214444.zip | |
[mlir] Update the CallGraph for nested symbol references, and simplify CallableOpInterface
Summary:
This enables tracking calls that cross symbol table boundaries. It also simplifies some of the implementation details of CallableOpInterface, i.e. there can only be one region within the callable operation.
Depends On D72042
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D72043
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/CallGraph.cpp | 81 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Inliner.cpp | 20 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 2 |
3 files changed, 44 insertions, 59 deletions
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index e88b1201443..cc82776c349 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -74,67 +74,38 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { /// Recursively compute the callgraph edges for the given operation. Computed /// edges are placed into the given callgraph object. static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode); - -/// Compute the set of callgraph nodes that are created by regions nested within -/// 'op'. -static void computeCallables(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { - if (op->getNumRegions() == 0) + CallGraphNode *parentNode, bool resolveCalls) { + if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { + // If there is no parent node, we ignore this operation. Even if this + // operation was a call, there would be no callgraph node to attribute it + // to. + if (!resolveCalls || !parentNode) + return; + parentNode->addCallEdge( + cg.resolveCallable(call.getCallableForCallee(), op)); return; - if (auto callableOp = dyn_cast<CallableOpInterface>(op)) { - SmallVector<Region *, 1> callables; - callableOp.getCallableRegions(callables); - for (auto *callableRegion : callables) - cg.getOrAddNode(callableRegion, parentNode); } -} -/// Recursively compute the callgraph edges within the given region. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Region ®ion, CallGraph &cg, - CallGraphNode *parentNode) { - // Iterate over the nested operations twice: - /// One to fully create nodes in the for each callable region of a nested - /// operation; - for (auto &block : region) - for (auto &nested : block) - computeCallables(&nested, cg, parentNode); - - /// And another to recursively compute the callgraph. - for (auto &block : region) - for (auto &nested : block) - computeCallGraph(&nested, cg, parentNode); -} - -/// Recursively compute the callgraph edges for the given operation. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { // Compute the callgraph nodes and edges for each of the nested operations. - auto isCallable = isa<CallableOpInterface>(op); - for (auto ®ion : op->getRegions()) { - // Check to see if this region is a callable node, if so this is the parent - // node of the nested region. - CallGraphNode *nestedParentNode; - if (!isCallable || !(nestedParentNode = cg.lookupNode(®ion))) - nestedParentNode = parentNode; - computeCallGraph(region, cg, nestedParentNode); + if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { + if (auto *callableRegion = callable.getCallableRegion()) + parentNode = cg.getOrAddNode(callableRegion, parentNode); + else + return; } - // If there is no parent node, we ignore this operation. Even if this - // operation was a call, there would be no callgraph node to attribute it to. - if (!parentNode) - return; - - // If this is a call operation, resolve the callee. - if (auto call = dyn_cast<CallOpInterface>(op)) - parentNode->addCallEdge( - cg.resolveCallable(call.getCallableForCallee(), op)); + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &nested : block) + computeCallGraph(&nested, cg, parentNode, resolveCalls); } CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { - computeCallGraph(op, *this, /*parentNode=*/nullptr); + // Make two passes over the graph, one to compute the callables and one to + // resolve the calls. We split these up as we may have nested callable objects + // that need to be reserved before the calls. + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false); + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true); } /// Get or add a call graph node for the given region. @@ -175,9 +146,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // Get the callee operation from the callable. Operation *callee; if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) - // TODO(riverriddle) Support nested references. - callee = SymbolTable::lookupNearestSymbolFrom(from, - symbolRef.getRootReference()); + callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef); else callee = callable.get<Value>().getDefiningOp(); @@ -185,7 +154,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // called region from it. if (callee && callee->getNumRegions()) { if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { - if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) + if (auto *node = lookupNode(callableOp.getCallableRegion())) return node; } } diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index b2cee7da083..d310316994f 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -86,8 +86,15 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast<CallOpInterface>(op)) { - CallGraphNode *node = - cg.resolveCallable(call.getCallableForCallee(), &op); + CallInterfaceCallable callable = call.getCallableForCallee(); + + // TODO(riverriddle) Support inlining nested call references. + if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { + if (!symRef.isa<FlatSymbolRefAttr>()) + continue; + } + + CallGraphNode *node = cg.resolveCallable(callable, &op); if (!node->isExternal()) calls.emplace_back(call, node); continue; @@ -274,6 +281,15 @@ struct InlinerPass : public OperationPass<InlinerPass> { CallGraph &cg = getAnalysis<CallGraph>(); auto *context = &getContext(); + // The inliner should only be run on operations that define a symbol table, + // as the callgraph will need to resolve references. + Operation *op = getOperation(); + if (!op->hasTrait<OpTrait::SymbolTable>()) { + op->emitOpError() << " was scheduled to run under the inliner, but does " + "not define a symbol table"; + return signalPassFailure(); + } + // Collect a set of canonicalization patterns to use when simplifying // callable regions within an SCC. OwningRewritePatternList canonPatterns; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 64591209dce..e91bdb71e9e 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -284,7 +284,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, if (src->empty()) return failure(); auto *entryBlock = &src->front(); - ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); + ArrayRef<Type> callableResultTypes = callable.getCallableResults(); // Make sure that the number of arguments and results matchup between the call // and the region. |

