summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-10-03 23:04:56 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-03 23:05:21 -0700
commita20d96e436272b52d36f52c4a07c86ed285502e9 (patch)
treea11b0ee1c681d4ce4b1705d231ec3af1d753f288 /mlir
parent8c95223e3c9555165fb9f56db35c3c8e85ddd4c1 (diff)
downloadbcm5719-llvm-a20d96e436272b52d36f52c4a07c86ed285502e9.tar.gz
bcm5719-llvm-a20d96e436272b52d36f52c4a07c86ed285502e9.zip
Update the Inliner pass to work on SCCs of the CallGraph.
This allows for the inliner to work on arbitrary call operations. The updated inliner will also work bottom-up through the callgraph enabling support for multiple levels of inlining. PiperOrigin-RevId: 272813876
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/CallInterfaces.td9
-rw-r--r--mlir/include/mlir/Transforms/InliningUtils.h6
-rw-r--r--mlir/lib/Transforms/Inliner.cpp186
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp3
-rw-r--r--mlir/test/Transforms/inlining.mlir33
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp4
6 files changed, 214 insertions, 27 deletions
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td
index de2dce98d37..fca7773ce63 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.td
+++ b/mlir/include/mlir/Analysis/CallInterfaces.td
@@ -39,7 +39,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
let description = [{
A call-like operation is one that transfers control from one sub-routine to
another. These operations may be traditional direct calls `call @foo`, or
- indirect calls to other operations `call_indirect %foo`.
+ indirect calls to other operations `call_indirect %foo`. An operation that
+ uses this interface, must *not* also provide the `CallableOpInterface`.
}];
let methods = [
@@ -50,6 +51,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"CallInterfaceCallable", "getCallableForCallee"
>,
+ InterfaceMethod<[{
+ Returns the operands within this call that are used as arguments to the
+ callee.
+ }],
+ "Operation::operand_range", "getArgOperands"
+ >,
];
}
diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index 693f20f9e10..7fe67e78127 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -24,6 +24,7 @@
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/Region.h"
namespace mlir {
@@ -116,6 +117,11 @@ public:
using Base::Base;
virtual ~InlinerInterface();
+ /// Process a set of blocks that have been inlined. This callback is invoked
+ /// *before* inlined terminator operations have been processed.
+ virtual void
+ processInlinedBlocks(llvm::iterator_range<Region::iterator> inlinedBlocks) {}
+
/// These hooks mirror the hooks for the DialectInlinerInterface, with default
/// implementations that call the hook on the handler for the dialect 'op' is
/// registered to.
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 49685cadba5..afb2dccc241 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -14,46 +14,180 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+//
+// This file implements a basic inlining algorithm that operates bottom up over
+// the Strongly Connect Components(SCCs) of the CallGraph. This enables a more
+// incremental propagation of inlining decisions from the leafs to the roots of
+// the callgraph.
+//
+//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/SCCIterator.h"
using namespace mlir;
-// TODO(riverriddle) This pass should currently only be used for basic testing
-// of inlining functionality.
+//===----------------------------------------------------------------------===//
+// CallGraph traversal
+//===----------------------------------------------------------------------===//
+
+/// Run a given transformation over the SCCs of the callgraph in a bottom up
+/// traversal.
+static void runTransformOnCGSCCs(
+ const CallGraph &cg,
+ function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
+ for (auto cgi = llvm::scc_begin(&cg); !cgi.isAtEnd(); ++cgi)
+ sccTransformer(*cgi);
+}
+
namespace {
-struct Inliner : public ModulePass<Inliner> {
- void runOnModule() override {
- auto module = getModule();
-
- // Collect each of the direct function calls within the module.
- SmallVector<CallOp, 16> callOps;
- for (auto &f : module)
- f.walk([&](CallOp callOp) { callOps.push_back(callOp); });
-
- // Build the inliner interface.
- InlinerInterface interface(&getContext());
-
- // Try to inline each of the call operations.
- for (auto &call : callOps) {
- if (failed(inlineFunction(
- interface, module.lookupSymbol<FuncOp>(call.getCallee()), call,
- llvm::to_vector<8>(call.getArgOperands()),
- llvm::to_vector<8>(call.getResults()), call.getLoc())))
+/// This struct represents a resolved call to a given callgraph node. Given that
+/// the call does not actually contain a direct reference to the
+/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
+/// explicitly.
+struct ResolvedCall {
+ ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
+ : call(call), targetNode(targetNode) {}
+ CallOpInterface call;
+ CallGraphNode *targetNode;
+};
+} // end anonymous namespace
+
+/// Collect all of the callable operations within the given range of blocks. If
+/// `traverseNestedCGNodes` is true, this will also collect call operations
+/// inside of nested callgraph nodes.
+static void collectCallOps(llvm::iterator_range<Region::iterator> blocks,
+ CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
+ bool traverseNestedCGNodes) {
+ SmallVector<Block *, 8> worklist;
+ auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) {
+ for (Block &block : blocks)
+ worklist.push_back(&block);
+ };
+
+ addToWorklist(blocks);
+ while (!worklist.empty()) {
+ for (Operation &op : *worklist.pop_back_val()) {
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ CallGraphNode *node =
+ cg.resolveCallable(call.getCallableForCallee(), &op);
+ if (!node->isExternal())
+ calls.emplace_back(call, node);
continue;
+ }
- // If the inlining was successful then erase the call.
- call.erase();
+ // If this is not a call, traverse the nested regions. If
+ // `traverseNestedCGNodes` is false, then don't traverse nested call graph
+ // regions.
+ for (auto &nestedRegion : op.getRegions())
+ if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
+ addToWorklist(nestedRegion);
}
}
+}
+
+//===----------------------------------------------------------------------===//
+// Inliner
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class provides a specialization of the main inlining interface.
+struct Inliner : public InlinerInterface {
+ Inliner(MLIRContext *context, CallGraph &cg)
+ : InlinerInterface(context), cg(cg) {}
+
+ /// Process a set of blocks that have been inlined. This callback is invoked
+ /// *before* inlined terminator operations have been processed.
+ void processInlinedBlocks(
+ llvm::iterator_range<Region::iterator> inlinedBlocks) final {
+ collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
+ }
+
+ /// The current set of call instructions to consider for inlining.
+ SmallVector<ResolvedCall, 8> calls;
+
+ /// The callgraph being operated on.
+ CallGraph &cg;
+};
+} // namespace
+
+/// Returns true if the given call should be inlined.
+static bool shouldInline(ResolvedCall &resolvedCall) {
+ // Don't allow inlining terminator calls. We currently don't support this
+ // case.
+ if (resolvedCall.call.getOperation()->isKnownTerminator())
+ return false;
+
+ // Don't allow inlining if the target is an ancestor of the call. This
+ // prevents inlining recursively.
+ if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
+ resolvedCall.call.getParentRegion()))
+ return false;
+
+ // Otherwise, inline.
+ return true;
+}
+
+/// Attempt to inline calls within the given scc.
+static void inlineCallsInSCC(Inliner &inliner,
+ ArrayRef<CallGraphNode *> currentSCC) {
+ CallGraph &cg = inliner.cg;
+ auto &calls = inliner.calls;
+
+ // Collect all of the direct calls within the nodes of the current SCC. We
+ // don't traverse nested callgraph nodes, because they are handled separately
+ // likely within a different SCC.
+ for (auto *node : currentSCC) {
+ if (!node->isExternal())
+ collectCallOps(*node->getCallableRegion(), cg, calls,
+ /*traverseNestedCGNodes=*/false);
+ }
+ if (calls.empty())
+ return;
+
+ // Try to inline each of the call operations. Don't cache the end iterator
+ // here as more calls may be added during inlining.
+ for (unsigned i = 0; i != calls.size(); ++i) {
+ ResolvedCall &it = calls[i];
+ if (!shouldInline(it))
+ continue;
+
+ CallOpInterface call = it.call;
+ LogicalResult inlineResult = inlineRegion(
+ inliner, it.targetNode->getCallableRegion(), call,
+ llvm::to_vector<8>(call.getArgOperands()),
+ llvm::to_vector<8>(call.getOperation()->getResults()), call.getLoc());
+ if (failed(inlineResult))
+ continue;
+
+ // If the inlining was successful, then erase the call.
+ call.erase();
+ }
+ calls.clear();
+}
+
+//===----------------------------------------------------------------------===//
+// InlinerPass
+//===----------------------------------------------------------------------===//
+
+// TODO(riverriddle) This pass should currently only be used for basic testing
+// of inlining functionality.
+namespace {
+struct InlinerPass : public OperationPass<InlinerPass> {
+ void runOnOperation() override {
+ CallGraph &cg = getAnalysis<CallGraph>();
+ Inliner inliner(&getContext(), cg);
+
+ // Run the inline transform in post-order over the SCCs in the callgraph.
+ runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
+ inlineCallsInSCC(inliner, scc);
+ });
+ }
};
} // end anonymous namespace
-static PassRegistration<Inliner> pass("inline", "Inline function calls");
+static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 901599ce023..6ca875b25ae 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -186,6 +186,9 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
if (!shouldCloneInlinedRegion)
remapInlinedOperands(newBlocks, mapper);
+ // Process the newly inlined blocks.
+ interface.processInlinedBlocks(newBlocks);
+
// Handle the case where only a single block was inlined.
if (std::next(newBlocks.begin()) == newBlocks.end()) {
// Have the interface handle the terminator of this block.
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index 3cfb5eee5c7..9732992b013 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -72,3 +72,36 @@ func @no_inline_external() {
call @func_external() : () -> ()
return
}
+
+// Check that multiple levels of calls will be inlined.
+func @multilevel_func_a() {
+ return
+}
+func @multilevel_func_b() {
+ call @multilevel_func_a() : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @inline_multilevel
+func @inline_multilevel() {
+ // CHECK-NOT: call
+ %fn = "test.functional_region_op"() ({
+ call @multilevel_func_b() : () -> ()
+ "test.return"() : () -> ()
+ }) : () -> (() -> ())
+
+ call_indirect %fn() : () -> ()
+ return
+}
+
+// Check that recursive calls are not inlined.
+// CHECK-LABEL: func @no_inline_recursive
+func @no_inline_recursive() {
+ // CHECK: test.functional_region_op
+ // CHECK-NOT: test.functional_region_op
+ %fn = "test.functional_region_op"() ({
+ call @no_inline_recursive() : () -> ()
+ "test.return"() : () -> ()
+ }) : () -> (() -> ())
+ return
+}
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index d91bb1a2f57..ca523d8a52f 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -49,6 +49,10 @@ struct TestInlinerInterface : public DialectInlinerInterface {
// Analysis Hooks
//===--------------------------------------------------------------------===//
+ bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
+ // Inlining into test dialect regions is legal.
+ return true;
+ }
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
OpenPOWER on IntegriCloud