summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Analysis/CallInterfaces.td22
-rw-r--r--mlir/include/mlir/IR/Function.h25
-rw-r--r--mlir/lib/Analysis/CallGraph.cpp81
-rw-r--r--mlir/lib/Transforms/Inliner.cpp20
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp2
-rw-r--r--mlir/test/Analysis/test-callgraph.mlir21
-rw-r--r--mlir/test/lib/TestDialect/TestOps.td11
7 files changed, 82 insertions, 100 deletions
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td
index 3e5b599baf8..bde5a2fc276 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.td
+++ b/mlir/include/mlir/Analysis/CallInterfaces.td
@@ -54,29 +54,23 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
be a target for a call-like operation (those providing the CallOpInterface
above). These operations may be traditional functional operation
`func @foo(...)`, as well as function producing operations
- `%foo = dialect.create_function(...)`. These operations may produce multiple
- callable regions, or subroutines.
+ `%foo = dialect.create_function(...)`. These operations may only contain a
+ single region, or subroutine.
}];
let methods = [
InterfaceMethod<[{
- Returns a region on the current operation that the given callable refers
- to. This may return null in the case of an external callable object,
- e.g. an external function.
+ Returns the region on the current operation that is callable. This may
+ return null in the case of an external callable object, e.g. an external
+ function.
}],
- "Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable)
+ "Region *", "getCallableRegion"
>,
InterfaceMethod<[{
- Returns all of the callable regions of this operation.
- }],
- "void", "getCallableRegions",
- (ins "SmallVectorImpl<Region *> &":$callables)
- >,
- InterfaceMethod<[{
- Returns the results types that the given callable region produces when
+ Returns the results types that the callable region produces when
executed.
}],
- "ArrayRef<Type>", "getCallableResults", (ins "Region *":$callable)
+ "ArrayRef<Type>", "getCallableResults"
>,
];
}
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 3f788bbeeba..5323b352a89 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -122,26 +122,13 @@ public:
// CallableOpInterface
//===--------------------------------------------------------------------===//
- /// Returns a region on the current operation that the given callable refers
- /// to. This may return null in the case of an external callable object, e.g.
- /// an external function.
- Region *getCallableRegion(CallInterfaceCallable callable) {
- assert(callable.get<SymbolRefAttr>().getLeafReference() == getName());
- return isExternal() ? nullptr : &getBody();
- }
-
- /// Returns all of the callable regions of this operation.
- void getCallableRegions(SmallVectorImpl<Region *> &callables) {
- if (!isExternal())
- callables.push_back(&getBody());
- }
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
- /// Returns the results types that the given callable region produces when
- /// executed.
- ArrayRef<Type> getCallableResults(Region *region) {
- assert(!isExternal() && region == &getBody() && "invalid callable");
- return getType().getResults();
- }
+ /// Returns the results types that the callable region produces when executed.
+ ArrayRef<Type> getCallableResults() { return getType().getResults(); }
private:
// This trait needs access to the hooks defined below.
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index e88b1201443..cc82776c349 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -74,67 +74,38 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
static void computeCallGraph(Operation *op, CallGraph &cg,
- CallGraphNode *parentNode);
-
-/// Compute the set of callgraph nodes that are created by regions nested within
-/// 'op'.
-static void computeCallables(Operation *op, CallGraph &cg,
- CallGraphNode *parentNode) {
- if (op->getNumRegions() == 0)
+ CallGraphNode *parentNode, bool resolveCalls) {
+ if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
+ // If there is no parent node, we ignore this operation. Even if this
+ // operation was a call, there would be no callgraph node to attribute it
+ // to.
+ if (!resolveCalls || !parentNode)
+ return;
+ parentNode->addCallEdge(
+ cg.resolveCallable(call.getCallableForCallee(), op));
return;
- if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
- SmallVector<Region *, 1> callables;
- callableOp.getCallableRegions(callables);
- for (auto *callableRegion : callables)
- cg.getOrAddNode(callableRegion, parentNode);
}
-}
-/// Recursively compute the callgraph edges within the given region. Computed
-/// edges are placed into the given callgraph object.
-static void computeCallGraph(Region &region, CallGraph &cg,
- CallGraphNode *parentNode) {
- // Iterate over the nested operations twice:
- /// One to fully create nodes in the for each callable region of a nested
- /// operation;
- for (auto &block : region)
- for (auto &nested : block)
- computeCallables(&nested, cg, parentNode);
-
- /// And another to recursively compute the callgraph.
- for (auto &block : region)
- for (auto &nested : block)
- computeCallGraph(&nested, cg, parentNode);
-}
-
-/// Recursively compute the callgraph edges for the given operation. Computed
-/// edges are placed into the given callgraph object.
-static void computeCallGraph(Operation *op, CallGraph &cg,
- CallGraphNode *parentNode) {
// Compute the callgraph nodes and edges for each of the nested operations.
- auto isCallable = isa<CallableOpInterface>(op);
- for (auto &region : op->getRegions()) {
- // Check to see if this region is a callable node, if so this is the parent
- // node of the nested region.
- CallGraphNode *nestedParentNode;
- if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
- nestedParentNode = parentNode;
- computeCallGraph(region, cg, nestedParentNode);
+ if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
+ if (auto *callableRegion = callable.getCallableRegion())
+ parentNode = cg.getOrAddNode(callableRegion, parentNode);
+ else
+ return;
}
- // If there is no parent node, we ignore this operation. Even if this
- // operation was a call, there would be no callgraph node to attribute it to.
- if (!parentNode)
- return;
-
- // If this is a call operation, resolve the callee.
- if (auto call = dyn_cast<CallOpInterface>(op))
- parentNode->addCallEdge(
- cg.resolveCallable(call.getCallableForCallee(), op));
+ for (Region &region : op->getRegions())
+ for (Block &block : region)
+ for (Operation &nested : block)
+ computeCallGraph(&nested, cg, parentNode, resolveCalls);
}
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
- computeCallGraph(op, *this, /*parentNode=*/nullptr);
+ // Make two passes over the graph, one to compute the callables and one to
+ // resolve the calls. We split these up as we may have nested callable objects
+ // that need to be reserved before the calls.
+ computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
+ computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
}
/// Get or add a call graph node for the given region.
@@ -175,9 +146,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
// Get the callee operation from the callable.
Operation *callee;
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
- // TODO(riverriddle) Support nested references.
- callee = SymbolTable::lookupNearestSymbolFrom(from,
- symbolRef.getRootReference());
+ callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
else
callee = callable.get<Value>().getDefiningOp();
@@ -185,7 +154,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
// called region from it.
if (callee && callee->getNumRegions()) {
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
- if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
+ if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
}
}
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index b2cee7da083..d310316994f 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -86,8 +86,15 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
while (!worklist.empty()) {
for (Operation &op : *worklist.pop_back_val()) {
if (auto call = dyn_cast<CallOpInterface>(op)) {
- CallGraphNode *node =
- cg.resolveCallable(call.getCallableForCallee(), &op);
+ CallInterfaceCallable callable = call.getCallableForCallee();
+
+ // TODO(riverriddle) Support inlining nested call references.
+ if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
+ if (!symRef.isa<FlatSymbolRefAttr>())
+ continue;
+ }
+
+ CallGraphNode *node = cg.resolveCallable(callable, &op);
if (!node->isExternal())
calls.emplace_back(call, node);
continue;
@@ -274,6 +281,15 @@ struct InlinerPass : public OperationPass<InlinerPass> {
CallGraph &cg = getAnalysis<CallGraph>();
auto *context = &getContext();
+ // The inliner should only be run on operations that define a symbol table,
+ // as the callgraph will need to resolve references.
+ Operation *op = getOperation();
+ if (!op->hasTrait<OpTrait::SymbolTable>()) {
+ op->emitOpError() << " was scheduled to run under the inliner, but does "
+ "not define a symbol table";
+ return signalPassFailure();
+ }
+
// Collect a set of canonicalization patterns to use when simplifying
// callable regions within an SCC.
OwningRewritePatternList canonPatterns;
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 64591209dce..e91bdb71e9e 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -284,7 +284,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
if (src->empty())
return failure();
auto *entryBlock = &src->front();
- ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
+ ArrayRef<Type> callableResultTypes = callable.getCallableResults();
// Make sure that the number of arguments and results matchup between the call
// and the region.
diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir
index 39e4fb8ba27..8c295ff248e 100644
--- a/mlir/test/Analysis/test-callgraph.mlir
+++ b/mlir/test/Analysis/test-callgraph.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-print-callgraph 2>&1 | FileCheck %s --dump-input-on-failure
+// RUN: mlir-opt %s -test-print-callgraph -split-input-file 2>&1 | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: Testing : "simple"
module attributes {test.name = "simple"} {
@@ -50,3 +50,22 @@ module attributes {test.name = "simple"} {
return
}
}
+
+// -----
+
+// CHECK-LABEL: Testing : "nested"
+module attributes {test.name = "nested"} {
+ module @nested_module {
+ // CHECK: Node{{.*}}func_a
+ func @func_a() {
+ return
+ }
+ }
+
+ // CHECK: Node{{.*}}func_b
+ // CHECK: Call-Edge{{.*}}func_a
+ func @func_b() {
+ "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> ()
+ return
+ }
+}
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 36eb545a39e..f10991dfe5b 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -230,7 +230,7 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
def ConversionCallOp : TEST_Op<"conversion_call_op",
[CallOpInterface]> {
- let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee);
+ let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
@@ -239,7 +239,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
/// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
- return getAttrOfType<FlatSymbolRefAttr>("callee");
+ return getAttrOfType<SymbolRefAttr>("callee");
}
}];
}
@@ -250,11 +250,8 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
let results = (outs FunctionType);
let extraClassDeclaration = [{
- Region *getCallableRegion(CallInterfaceCallable) { return &body(); }
- void getCallableRegions(SmallVectorImpl<Region *> &callables) {
- callables.push_back(&body());
- }
- ArrayRef<Type> getCallableResults(Region *) {
+ Region *getCallableRegion() { return &body(); }
+ ArrayRef<Type> getCallableResults() {
return getType().cast<FunctionType>().getResults();
}
}];
OpenPOWER on IntegriCloud