summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Inliner.cpp
diff options
context:
space:
mode:
authorMehdi Amini <aminim@google.com>2019-12-24 02:47:41 +0000
committerMehdi Amini <aminim@google.com>2019-12-24 02:47:41 +0000
commit0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch)
tree31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/lib/Transforms/Inliner.cpp
parent6f635f90929da9545dd696071a829a1a42f84b30 (diff)
parent5b4a01d4a63cb66ab981e52548f940813393bf42 (diff)
downloadbcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.tar.gz
bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.zip
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/lib/Transforms/Inliner.cpp')
-rw-r--r--mlir/lib/Transforms/Inliner.cpp296
1 files changed, 296 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
new file mode 100644
index 00000000000..b2cee7da083
--- /dev/null
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -0,0 +1,296 @@
+//===- Inliner.cpp - Pass to inline function calls ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a basic inlining algorithm that operates bottom up over
+// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
+// more incremental propagation of inlining decisions from the leaves to the
+// roots of the callgraph.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/CallGraph.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Parallel.h"
+
+#define DEBUG_TYPE "inlining"
+
+using namespace mlir;
+
+static llvm::cl::opt<bool> disableCanonicalization( // Defaults to false.
+ "mlir-disable-inline-simplify", // When set, inlineSCC skips canonicalizeSCC.
+ llvm::cl::desc("Disable running simplifications during inlining"),
+ llvm::cl::ReallyHidden, llvm::cl::init(false));
+
+static llvm::cl::opt<unsigned> maxInliningIterations( // Defaults to 4.
+ "mlir-max-inline-iterations", // Caps the loop in inlineSCC.
+ llvm::cl::desc("Maximum number of iterations when inlining within an SCC"),
+ llvm::cl::ReallyHidden, llvm::cl::init(4));
+
+//===----------------------------------------------------------------------===//
+// CallGraph traversal
+//===----------------------------------------------------------------------===//
+
+/// Visit the SCCs of the given callgraph bottom up, handing each one to
+/// `sccTransformer` in turn.
+static void runTransformOnCGSCCs(
+    const CallGraph &cg,
+    function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
+  std::vector<CallGraphNode *> sccSnapshot;
+  for (auto sccIt = llvm::scc_begin(&cg); !sccIt.isAtEnd();) {
+    // Snapshot the nodes of this SCC and step the iterator *before* invoking
+    // the callback, so that the callback is free to modify the SCC without
+    // invalidating the iterator.
+    sccSnapshot = *sccIt;
+    ++sccIt;
+    sccTransformer(sccSnapshot);
+  }
+}
+
+namespace {
+/// A call operation paired with the callgraph node it dispatches to. The call
+/// itself holds no direct reference to the target Region(CallGraphNode), so the
+/// callee is resolved explicitly and recorded alongside the call.
+struct ResolvedCall {
+  CallOpInterface call;
+  CallGraphNode *targetNode;
+
+  ResolvedCall(CallOpInterface callOp, CallGraphNode *resolvedNode)
+      : call(callOp), targetNode(resolvedNode) {}
+};
+} // end anonymous namespace
+
+/// Gather the call operations within the given range of blocks that resolve to
+/// a non-external callgraph node, appending them to `calls`. Regions that are
+/// callgraph nodes themselves are only entered if `traverseNestedCGNodes`.
+static void collectCallOps(iterator_range<Region::iterator> blocks,
+                           CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
+                           bool traverseNestedCGNodes) {
+  SmallVector<Block *, 8> pendingBlocks;
+  auto enqueueBlocks = [&](iterator_range<Region::iterator> blockRange) {
+    for (Block &pending : blockRange)
+      pendingBlocks.push_back(&pending);
+  };
+  enqueueBlocks(blocks);
+
+  while (!pendingBlocks.empty()) {
+    Block *current = pendingBlocks.pop_back_val();
+    for (Operation &op : *current) {
+      // Operations other than calls may hold more blocks in their regions.
+      // Regions with their own callgraph node are only entered on request.
+      auto call = dyn_cast<CallOpInterface>(op);
+      if (!call) {
+        for (auto &nestedRegion : op.getRegions())
+          if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
+            enqueueBlocks(nestedRegion);
+        continue;
+      }
+      // Record the call, unless it resolves to an external node.
+      auto *callee = cg.resolveCallable(call.getCallableForCallee(), &op);
+      if (!callee->isExternal())
+        calls.emplace_back(call, callee);
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Inliner
+//===----------------------------------------------------------------------===//
+namespace {
+/// InlinerInterface specialization that queues the calls exposed by inlining.
+struct Inliner : public InlinerInterface {
+ Inliner(MLIRContext *context, CallGraph &cg)
+ : InlinerInterface(context), cg(cg) {}
+
+ /// Resolve the calls within a range of newly inlined blocks and append them
+ /// to `calls`. Invoked *before* inlined terminator operations are processed.
+ void
+ processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
+ collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
+ }
+
+ /// The current worklist of resolved calls to consider for inlining.
+ SmallVector<ResolvedCall, 8> calls;
+
+ /// The callgraph being operated on, used to resolve calls; held by reference.
+ CallGraph &cg;
+};
+} // namespace
+
+/// Decide whether the given resolved call is a candidate for inlining. Returns
+/// true if the call should be inlined.
+static bool shouldInline(ResolvedCall &resolvedCall) {
+  // Calls that are terminators are rejected, as inlining them is currently
+  // unsupported.
+  Operation *callOp = resolvedCall.call.getOperation();
+  if (callOp->isKnownTerminator())
+    return false;
+
+  // Everything else is inlined, unless the callee region is an ancestor of
+  // the region containing the call: inlining in that case would be inlining
+  // recursively.
+  Region *calleeRegion = resolvedCall.targetNode->getCallableRegion();
+  auto *callerRegion = resolvedCall.call.getParentRegion();
+  return !calleeRegion->isAncestor(callerRegion);
+}
+
+/// Attempt to inline the calls found within the nodes of the given scc. Returns
+/// success if any calls were inlined, failure otherwise.
+static LogicalResult inlineCallsInSCC(Inliner &inliner,
+ ArrayRef<CallGraphNode *> currentSCC) {
+ CallGraph &cg = inliner.cg;
+ auto &calls = inliner.calls;
+
+ // Collect all of the direct calls within the nodes of the current SCC. Nested
+ // callgraph nodes are not traversed, because they are handled separately,
+ // likely within a different SCC.
+ for (auto *node : currentSCC) {
+ if (!node->isExternal())
+ collectCallOps(*node->getCallableRegion(), cg, calls,
+ /*traverseNestedCGNodes=*/false);
+ }
+ if (calls.empty())
+ return failure();
+
+ // Try to inline each of the call operations. `calls.size()` is re-read on
+ // every iteration, as more calls may be added to `calls` during inlining.
+ bool inlinedAnyCalls = false;
+ for (unsigned i = 0; i != calls.size(); ++i) {
+ ResolvedCall &it = calls[i]; // Only used before `inlineCall` below.
+ LLVM_DEBUG({
+ llvm::dbgs() << "* Considering inlining call: ";
+ it.call.dump();
+ });
+ if (!shouldInline(it))
+ continue;
+
+ CallOpInterface call = it.call; // Copied out of `calls` before inlining.
+ Region *targetRegion = it.targetNode->getCallableRegion();
+ LogicalResult inlineResult = inlineCall(
+ inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
+ targetRegion);
+ if (failed(inlineResult))
+ continue;
+
+ // The inlining was successful, so the original call is no longer needed.
+ call.erase();
+ inlinedAnyCalls = true;
+ }
+ calls.clear(); // Leave the worklist empty for the next invocation.
+ return success(inlinedAnyCalls);
+}
+
+/// Canonicalize the eligible nodes of the given SCC, in parallel, by applying
+/// `canonPatterns` to their callable regions with `applyPatternsGreedily`.
+static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC,
+ MLIRContext *context,
+ const OwningRewritePatternList &canonPatterns) {
+ // Collect the nodes of the SCC that are eligible for canonicalization.
+ SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
+ for (auto *node : currentSCC) {
+ // Skip the external node: it has no valid callable region.
+ if (node->isExternal())
+ continue;
+
+ // Skip nodes with children. These require special handling, as a node may
+ // be removed during canonicalization. In the future, this case should be
+ // supported by tracking node deletion properly; until then such nodes are
+ // left untouched.
+ if (node->hasChildren())
+ continue;
+
+ // Skip nodes whose parent operation is not known to be isolated from above.
+ // This avoids potentially mutating the regions of nodes defined above, and
+ // it is also a stipulation of the 'applyPatternsGreedily' driver.
+ auto *region = node->getCallableRegion();
+ if (!region->getParentOp()->isKnownIsolatedFromAbove())
+ continue;
+ nodesToCanonicalize.push_back(node);
+ }
+ if (nodesToCanonicalize.empty())
+ return;
+
+ // Canonicalize each of the collected nodes in parallel.
+ // NOTE: This is simple for now, because nodes with children are never
+ // canonicalized. When that restriction is removed, this logic will need to
+ // be reworked.
+ ParallelDiagnosticHandler canonicalizationHandler(context);
+ llvm::parallel::for_each_n(
+ llvm::parallel::par, /*Begin=*/size_t(0),
+ /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
+ // Use the node index as the order ID for this thread, so that
+ // diagnostics will be properly ordered.
+ canonicalizationHandler.setOrderIDForThread(index);
+
+ // Apply the canonicalization patterns to the region of this node.
+ auto *node = nodesToCanonicalize[index];
+ applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);
+
+ // Erase the order ID from the diagnostic handler again, as this thread
+ // may be used in a different context.
+ canonicalizationHandler.eraseOrderIDForThread();
+ });
+}
+
+/// Alternate between inlining the calls of the given SCC and canonicalizing
+/// its nodes with `canonPatterns`, until no more calls get inlined or the
+/// iteration limit is hit. This allows inlining newly devirtualized calls.
+static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC,
+                      MLIRContext *context,
+                      const OwningRewritePatternList &canonPatterns) {
+  for (unsigned iteration = 1;; ++iteration) {
+    // Nothing was inlined in this round: a fixed point has been reached.
+    if (failed(inlineCallsInSCC(inliner, currentSCC)))
+      return;
+    // Stop if simplifications are disabled, or if this was the last round
+    // permitted by the maximum iteration count.
+    if (disableCanonicalization || iteration >= maxInliningIterations)
+      return;
+    // Simplify the nodes of the SCC before inlining again, as this may
+    // devirtualize new calls, as well as give us a better cost model.
+    canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// InlinerPass
+//===----------------------------------------------------------------------===//
+
+// TODO(riverriddle) This pass should currently only be used for basic testing
+// of inlining functionality.
+namespace {
+struct InlinerPass : public OperationPass<InlinerPass> {
+  void runOnOperation() override {
+    CallGraph &callGraph = getAnalysis<CallGraph>();
+    MLIRContext *ctx = &getContext();
+    // Gather the canonicalization patterns of every registered operation.
+    // They are used to simplify the callable regions within each SCC.
+    OwningRewritePatternList patterns;
+    for (auto *registeredOp : ctx->getRegisteredOperations())
+      registeredOp->getCanonicalizationPatterns(patterns, ctx);
+
+    // Inline over the SCCs of the callgraph, visited in post-order.
+    Inliner inliner(ctx, callGraph);
+    auto processSCC = [&](ArrayRef<CallGraphNode *> scc) {
+      inlineSCC(inliner, scc, ctx, patterns);
+    };
+    runTransformOnCGSCCs(callGraph, processSCC);
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createInlinerPass() { // Factory for InlinerPass.
+ return std::make_unique<InlinerPass>();
+}
+
+static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); // Registers InlinerPass as "inline".
OpenPOWER on IntegriCloud