diff options
author | River Riddle <riverriddle@google.com> | 2020-01-13 15:46:40 -0800 |
---|---|---|
committer | River Riddle <riverriddle@google.com> | 2020-01-13 15:51:28 -0800 |
commit | c7748404920b3674e79059cbbe73b6041a214444 (patch) | |
tree | e7dc6a063e4f67e6b7b79f6d2fc71b067a315fa2 /mlir | |
parent | 6fca03f0cae77c275870c4569bfeeb7ca0f561a6 (diff) | |
download | bcm5719-llvm-c7748404920b3674e79059cbbe73b6041a214444.tar.gz bcm5719-llvm-c7748404920b3674e79059cbbe73b6041a214444.zip |
[mlir] Update the CallGraph for nested symbol references, and simplify CallableOpInterface
Summary:
This enables tracking calls that cross symbol table boundaries. It also simplifies some of the implementation details of CallableOpInterface, i.e. there can only be one region within the callable operation.
Depends On D72042
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D72043
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Analysis/CallInterfaces.td | 22 | ||||
-rw-r--r-- | mlir/include/mlir/IR/Function.h | 25 | ||||
-rw-r--r-- | mlir/lib/Analysis/CallGraph.cpp | 81 | ||||
-rw-r--r-- | mlir/lib/Transforms/Inliner.cpp | 20 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 2 | ||||
-rw-r--r-- | mlir/test/Analysis/test-callgraph.mlir | 21 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestOps.td | 11 |
7 files changed, 82 insertions, 100 deletions
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td index 3e5b599baf8..bde5a2fc276 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.td +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -54,29 +54,23 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { be a target for a call-like operation (those providing the CallOpInterface above). These operations may be traditional functional operation `func @foo(...)`, as well as function producing operations - `%foo = dialect.create_function(...)`. These operations may produce multiple - callable regions, or subroutines. + `%foo = dialect.create_function(...)`. These operations may only contain a + single region, or subroutine. }]; let methods = [ InterfaceMethod<[{ - Returns a region on the current operation that the given callable refers - to. This may return null in the case of an external callable object, - e.g. an external function. + Returns the region on the current operation that is callable. This may + return null in the case of an external callable object, e.g. an external + function. }], - "Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable) + "Region *", "getCallableRegion" >, InterfaceMethod<[{ - Returns all of the callable regions of this operation. - }], - "void", "getCallableRegions", - (ins "SmallVectorImpl<Region *> &":$callables) - >, - InterfaceMethod<[{ - Returns the results types that the given callable region produces when + Returns the results types that the callable region produces when executed. }], - "ArrayRef<Type>", "getCallableResults", (ins "Region *":$callable) + "ArrayRef<Type>", "getCallableResults" >, ]; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 3f788bbeeba..5323b352a89 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -122,26 +122,13 @@ public: // CallableOpInterface //===--------------------------------------------------------------------===// - /// Returns a region on the current operation that the given callable refers - /// to. This may return null in the case of an external callable object, e.g. - /// an external function. - Region *getCallableRegion(CallInterfaceCallable callable) { - assert(callable.get<SymbolRefAttr>().getLeafReference() == getName()); - return isExternal() ? nullptr : &getBody(); - } - - /// Returns all of the callable regions of this operation. - void getCallableRegions(SmallVectorImpl<Region *> &callables) { - if (!isExternal()) - callables.push_back(&getBody()); - } + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } - /// Returns the results types that the given callable region produces when - /// executed. - ArrayRef<Type> getCallableResults(Region *region) { - assert(!isExternal() && region == &getBody() && "invalid callable"); - return getType().getResults(); - } + /// Returns the results types that the callable region produces when executed. + ArrayRef<Type> getCallableResults() { return getType().getResults(); } private: // This trait needs access to the hooks defined below. diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index e88b1201443..cc82776c349 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -74,67 +74,38 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { /// Recursively compute the callgraph edges for the given operation. Computed /// edges are placed into the given callgraph object. static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode); - -/// Compute the set of callgraph nodes that are created by regions nested within -/// 'op'. -static void computeCallables(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { - if (op->getNumRegions() == 0) + CallGraphNode *parentNode, bool resolveCalls) { + if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { + // If there is no parent node, we ignore this operation. Even if this + // operation was a call, there would be no callgraph node to attribute it + // to. + if (!resolveCalls || !parentNode) + return; + parentNode->addCallEdge( + cg.resolveCallable(call.getCallableForCallee(), op)); return; - if (auto callableOp = dyn_cast<CallableOpInterface>(op)) { - SmallVector<Region *, 1> callables; - callableOp.getCallableRegions(callables); - for (auto *callableRegion : callables) - cg.getOrAddNode(callableRegion, parentNode); } -} -/// Recursively compute the callgraph edges within the given region. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Region ®ion, CallGraph &cg, - CallGraphNode *parentNode) { - // Iterate over the nested operations twice: - /// One to fully create nodes in the for each callable region of a nested - /// operation; - for (auto &block : region) - for (auto &nested : block) - computeCallables(&nested, cg, parentNode); - - /// And another to recursively compute the callgraph. - for (auto &block : region) - for (auto &nested : block) - computeCallGraph(&nested, cg, parentNode); -} - -/// Recursively compute the callgraph edges for the given operation. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { // Compute the callgraph nodes and edges for each of the nested operations. - auto isCallable = isa<CallableOpInterface>(op); - for (auto ®ion : op->getRegions()) { - // Check to see if this region is a callable node, if so this is the parent - // node of the nested region. - CallGraphNode *nestedParentNode; - if (!isCallable || !(nestedParentNode = cg.lookupNode(®ion))) - nestedParentNode = parentNode; - computeCallGraph(region, cg, nestedParentNode); + if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { + if (auto *callableRegion = callable.getCallableRegion()) + parentNode = cg.getOrAddNode(callableRegion, parentNode); + else + return; } - // If there is no parent node, we ignore this operation. Even if this - // operation was a call, there would be no callgraph node to attribute it to. - if (!parentNode) - return; - - // If this is a call operation, resolve the callee. - if (auto call = dyn_cast<CallOpInterface>(op)) - parentNode->addCallEdge( - cg.resolveCallable(call.getCallableForCallee(), op)); + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &nested : block) + computeCallGraph(&nested, cg, parentNode, resolveCalls); } CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { - computeCallGraph(op, *this, /*parentNode=*/nullptr); + // Make two passes over the graph, one to compute the callables and one to + // resolve the calls. We split these up as we may have nested callable objects + // that need to be reserved before the calls. + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false); + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true); } /// Get or add a call graph node for the given region. @@ -175,9 +146,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // Get the callee operation from the callable. Operation *callee; if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) - // TODO(riverriddle) Support nested references. - callee = SymbolTable::lookupNearestSymbolFrom(from, - symbolRef.getRootReference()); + callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef); else callee = callable.get<Value>().getDefiningOp(); @@ -185,7 +154,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // called region from it. if (callee && callee->getNumRegions()) { if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { - if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) + if (auto *node = lookupNode(callableOp.getCallableRegion())) return node; } } diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index b2cee7da083..d310316994f 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -86,8 +86,15 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast<CallOpInterface>(op)) { - CallGraphNode *node = - cg.resolveCallable(call.getCallableForCallee(), &op); + CallInterfaceCallable callable = call.getCallableForCallee(); + + // TODO(riverriddle) Support inlining nested call references. + if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { + if (!symRef.isa<FlatSymbolRefAttr>()) + continue; + } + + CallGraphNode *node = cg.resolveCallable(callable, &op); if (!node->isExternal()) calls.emplace_back(call, node); continue; @@ -274,6 +281,15 @@ struct InlinerPass : public OperationPass<InlinerPass> { CallGraph &cg = getAnalysis<CallGraph>(); auto *context = &getContext(); + // The inliner should only be run on operations that define a symbol table, + // as the callgraph will need to resolve references. + Operation *op = getOperation(); + if (!op->hasTrait<OpTrait::SymbolTable>()) { + op->emitOpError() << " was scheduled to run under the inliner, but does " + "not define a symbol table"; + return signalPassFailure(); + } + // Collect a set of canonicalization patterns to use when simplifying // callable regions within an SCC. OwningRewritePatternList canonPatterns; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 64591209dce..e91bdb71e9e 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -284,7 +284,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, if (src->empty()) return failure(); auto *entryBlock = &src->front(); - ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); + ArrayRef<Type> callableResultTypes = callable.getCallableResults(); // Make sure that the number of arguments and results matchup between the call // and the region. diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir index 39e4fb8ba27..8c295ff248e 100644 --- a/mlir/test/Analysis/test-callgraph.mlir +++ b/mlir/test/Analysis/test-callgraph.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-print-callgraph 2>&1 | FileCheck %s --dump-input-on-failure +// RUN: mlir-opt %s -test-print-callgraph -split-input-file 2>&1 | FileCheck %s --dump-input-on-failure // CHECK-LABEL: Testing : "simple" module attributes {test.name = "simple"} { @@ -50,3 +50,22 @@ module attributes {test.name = "simple"} { return } } + +// ----- + +// CHECK-LABEL: Testing : "nested" +module attributes {test.name = "nested"} { + module @nested_module { + // CHECK: Node{{.*}}func_a + func @func_a() { + return + } + } + + // CHECK: Node{{.*}}func_b + // CHECK: Call-Edge{{.*}}func_a + func @func_b() { + "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> () + return + } +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 36eb545a39e..f10991dfe5b 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -230,7 +230,7 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> { def ConversionCallOp : TEST_Op<"conversion_call_op", [CallOpInterface]> { - let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee); + let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee); let results = (outs Variadic<AnyType>); let extraClassDeclaration = [{ @@ -239,7 +239,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op", /// Return the callee of this operation. CallInterfaceCallable getCallableForCallee() { - return getAttrOfType<FlatSymbolRefAttr>("callee"); + return getAttrOfType<SymbolRefAttr>("callee"); } }]; } @@ -250,11 +250,8 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op", let results = (outs FunctionType); let extraClassDeclaration = [{ - Region *getCallableRegion(CallInterfaceCallable) { return &body(); } - void getCallableRegions(SmallVectorImpl<Region *> &callables) { - callables.push_back(&body()); - } - ArrayRef<Type> getCallableResults(Region *) { + Region *getCallableRegion() { return &body(); } + ArrayRef<Type> getCallableResults() { return getType().cast<FunctionType>().getResults(); } }]; |