//===- Inliner.cpp - Pass to inline function calls ------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a basic inlining algorithm that operates bottom up over // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more // incremental propagation of inlining decisions from the leafs to the roots of // the callgraph. // //===----------------------------------------------------------------------===// #include "mlir/Analysis/CallGraph.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Parallel.h" #define DEBUG_TYPE "inlining" using namespace mlir; static llvm::cl::opt disableCanonicalization( "mlir-disable-inline-simplify", llvm::cl::desc("Disable running simplifications during inlining"), llvm::cl::ReallyHidden, llvm::cl::init(false)); static llvm::cl::opt maxInliningIterations( "mlir-max-inline-iterations", llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), llvm::cl::ReallyHidden, llvm::cl::init(4)); //===----------------------------------------------------------------------===// // CallGraph traversal //===----------------------------------------------------------------------===// /// Run a given transformation over the SCCs of the callgraph in a bottom up /// traversal. static void runTransformOnCGSCCs( const CallGraph &cg, function_ref)> sccTransformer) { std::vector currentSCCVec; auto cgi = llvm::scc_begin(&cg); while (!cgi.isAtEnd()) { // Copy the current SCC and increment so that the transformer can modify the // SCC without invalidating our iterator. currentSCCVec = *cgi; ++cgi; sccTransformer(currentSCCVec); } } namespace { /// This struct represents a resolved call to a given callgraph node. Given that /// the call does not actually contain a direct reference to the /// Region(CallGraphNode) that it is dispatching to, we need to resolve them /// explicitly. struct ResolvedCall { ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) : call(call), targetNode(targetNode) {} CallOpInterface call; CallGraphNode *targetNode; }; } // end anonymous namespace /// Collect all of the callable operations within the given range of blocks. If /// `traverseNestedCGNodes` is true, this will also collect call operations /// inside of nested callgraph nodes. static void collectCallOps(iterator_range blocks, CallGraph &cg, SmallVectorImpl &calls, bool traverseNestedCGNodes) { SmallVector worklist; auto addToWorklist = [&](iterator_range blocks) { for (Block &block : blocks) worklist.push_back(&block); }; addToWorklist(blocks); while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast(op)) { CallInterfaceCallable callable = call.getCallableForCallee(); // TODO(riverriddle) Support inlining nested call references. if (SymbolRefAttr symRef = callable.dyn_cast()) { if (!symRef.isa()) continue; } CallGraphNode *node = cg.resolveCallable(callable, &op); if (!node->isExternal()) calls.emplace_back(call, node); continue; } // If this is not a call, traverse the nested regions. If // `traverseNestedCGNodes` is false, then don't traverse nested call graph // regions. for (auto &nestedRegion : op.getRegions()) if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) addToWorklist(nestedRegion); } } } //===----------------------------------------------------------------------===// // Inliner //===----------------------------------------------------------------------===// namespace { /// This class provides a specialization of the main inlining interface. struct Inliner : public InlinerInterface { Inliner(MLIRContext *context, CallGraph &cg) : InlinerInterface(context), cg(cg) {} /// Process a set of blocks that have been inlined. This callback is invoked /// *before* inlined terminator operations have been processed. void processInlinedBlocks(iterator_range inlinedBlocks) final { collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); } /// The current set of call instructions to consider for inlining. SmallVector calls; /// The callgraph being operated on. CallGraph &cg; }; } // namespace /// Returns true if the given call should be inlined. static bool shouldInline(ResolvedCall &resolvedCall) { // Don't allow inlining terminator calls. We currently don't support this // case. if (resolvedCall.call.getOperation()->isKnownTerminator()) return false; // Don't allow inlining if the target is an ancestor of the call. This // prevents inlining recursively. if (resolvedCall.targetNode->getCallableRegion()->isAncestor( resolvedCall.call.getParentRegion())) return false; // Otherwise, inline. return true; } /// Attempt to inline calls within the given scc. This function returns /// success if any calls were inlined, failure otherwise. static LogicalResult inlineCallsInSCC(Inliner &inliner, ArrayRef currentSCC) { CallGraph &cg = inliner.cg; auto &calls = inliner.calls; // Collect all of the direct calls within the nodes of the current SCC. We // don't traverse nested callgraph nodes, because they are handled separately // likely within a different SCC. for (auto *node : currentSCC) { if (!node->isExternal()) collectCallOps(*node->getCallableRegion(), cg, calls, /*traverseNestedCGNodes=*/false); } if (calls.empty()) return failure(); // Try to inline each of the call operations. Don't cache the end iterator // here as more calls may be added during inlining. bool inlinedAnyCalls = false; for (unsigned i = 0; i != calls.size(); ++i) { ResolvedCall &it = calls[i]; LLVM_DEBUG({ llvm::dbgs() << "* Considering inlining call: "; it.call.dump(); }); if (!shouldInline(it)) continue; CallOpInterface call = it.call; Region *targetRegion = it.targetNode->getCallableRegion(); LogicalResult inlineResult = inlineCall( inliner, call, cast(targetRegion->getParentOp()), targetRegion); if (failed(inlineResult)) continue; // If the inlining was successful, then erase the call. call.erase(); inlinedAnyCalls = true; } calls.clear(); return success(inlinedAnyCalls); } /// Canonicalize the nodes within the given SCC with the given set of /// canonicalization patterns. static void canonicalizeSCC(CallGraph &cg, ArrayRef currentSCC, MLIRContext *context, const OwningRewritePatternList &canonPatterns) { // Collect the sets of nodes to canonicalize. SmallVector nodesToCanonicalize; for (auto *node : currentSCC) { // Don't canonicalize the external node, it has no valid callable region. if (node->isExternal()) continue; // Don't canonicalize nodes with children. Nodes with children // require special handling as we may remove the node during // canonicalization. In the future, we should be able to handle this // case with proper node deletion tracking. if (node->hasChildren()) continue; // We also won't apply canonicalizations for nodes that are not // isolated. This avoids potentially mutating the regions of nodes defined // above, this is also a stipulation of the 'applyPatternsGreedily' driver. auto *region = node->getCallableRegion(); if (!region->getParentOp()->isKnownIsolatedFromAbove()) continue; nodesToCanonicalize.push_back(node); } if (nodesToCanonicalize.empty()) return; // Canonicalize each of the nodes within the SCC in parallel. // NOTE: This is simple now, because we don't enable canonicalizing nodes // within children. When we remove this restriction, this logic will need to // be reworked. ParallelDiagnosticHandler canonicalizationHandler(context); llvm::parallel::for_each_n( llvm::parallel::par, /*Begin=*/size_t(0), /*End=*/nodesToCanonicalize.size(), [&](size_t index) { // Set the order for this thread so that diagnostics will be properly // ordered. canonicalizationHandler.setOrderIDForThread(index); // Apply the canonicalization patterns to this region. auto *node = nodesToCanonicalize[index]; applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); // Make sure to reset the order ID for the diagnostic handler, as this // thread may be used in a different context. canonicalizationHandler.eraseOrderIDForThread(); }); } /// Attempt to inline calls within the given scc, and run canonicalizations with /// the given patterns, until a fixed point is reached. This allows for the /// inlining of newly devirtualized calls. static void inlineSCC(Inliner &inliner, ArrayRef currentSCC, MLIRContext *context, const OwningRewritePatternList &canonPatterns) { // If we successfully inlined any calls, run some simplifications on the // nodes of the scc. Continue attempting to inline until we reach a fixed // point, or a maximum iteration count. We canonicalize here as it may // devirtualize new calls, as well as give us a better cost model. unsigned iterationCount = 0; while (succeeded(inlineCallsInSCC(inliner, currentSCC))) { // If we aren't allowing simplifications or the max iteration count was // reached, then bail out early. if (disableCanonicalization || ++iterationCount >= maxInliningIterations) break; canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns); } } //===----------------------------------------------------------------------===// // InlinerPass //===----------------------------------------------------------------------===// // TODO(riverriddle) This pass should currently only be used for basic testing // of inlining functionality. namespace { struct InlinerPass : public OperationPass { void runOnOperation() override { CallGraph &cg = getAnalysis(); auto *context = &getContext(); // The inliner should only be run on operations that define a symbol table, // as the callgraph will need to resolve references. Operation *op = getOperation(); if (!op->hasTrait()) { op->emitOpError() << " was scheduled to run under the inliner, but does " "not define a symbol table"; return signalPassFailure(); } // Collect a set of canonicalization patterns to use when simplifying // callable regions within an SCC. OwningRewritePatternList canonPatterns; for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(canonPatterns, context); // Run the inline transform in post-order over the SCCs in the callgraph. Inliner inliner(context, cg); runTransformOnCGSCCs(cg, [&](ArrayRef scc) { inlineSCC(inliner, scc, context, canonPatterns); }); } }; } // end anonymous namespace std::unique_ptr mlir::createInlinerPass() { return std::make_unique(); } static PassRegistration pass("inline", "Inline function calls");