diff options
author | River Riddle <riverriddle@google.com> | 2019-10-03 23:04:56 -0700 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-10-03 23:05:21 -0700 |
commit | a20d96e436272b52d36f52c4a07c86ed285502e9 (patch) | |
tree | a11b0ee1c681d4ce4b1705d231ec3af1d753f288 /mlir | |
parent | 8c95223e3c9555165fb9f56db35c3c8e85ddd4c1 (diff) | |
download | bcm5719-llvm-a20d96e436272b52d36f52c4a07c86ed285502e9.tar.gz bcm5719-llvm-a20d96e436272b52d36f52c4a07c86ed285502e9.zip |
Update the Inliner pass to work on SCCs of the CallGraph.
This allows for the inliner to work on arbitrary call operations. The updated inliner will also work bottom-up through the callgraph enabling support for multiple levels of inlining.
PiperOrigin-RevId: 272813876
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Analysis/CallInterfaces.td | 9 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/InliningUtils.h | 6 | ||||
-rw-r--r-- | mlir/lib/Transforms/Inliner.cpp | 186 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 3 | ||||
-rw-r--r-- | mlir/test/Transforms/inlining.mlir | 33 | ||||
-rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.cpp | 4 |
6 files changed, 214 insertions, 27 deletions
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td index de2dce98d37..fca7773ce63 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.td +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -39,7 +39,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { let description = [{ A call-like operation is one that transfers control from one sub-routine to another. These operations may be traditional direct calls `call @foo`, or - indirect calls to other operations `call_indirect %foo`. + indirect calls to other operations `call_indirect %foo`. An operation that + uses this interface, must *not* also provide the `CallableOpInterface`. }]; let methods = [ @@ -50,6 +51,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { }], "CallInterfaceCallable", "getCallableForCallee" >, + InterfaceMethod<[{ + Returns the operands within this call that are used as arguments to the + callee. + }], + "Operation::operand_range", "getArgOperands" + >, ]; } diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index 693f20f9e10..7fe67e78127 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -24,6 +24,7 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Region.h" namespace mlir { @@ -116,6 +117,11 @@ public: using Base::Base; virtual ~InlinerInterface(); + /// Process a set of blocks that have been inlined. This callback is invoked + /// *before* inlined terminator operations have been processed. + virtual void + processInlinedBlocks(llvm::iterator_range<Region::iterator> inlinedBlocks) {} + /// These hooks mirror the hooks for the DialectInlinerInterface, with default /// implementations that call the hook on the handler for the dialect 'op' is /// registered to. diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index 49685cadba5..afb2dccc241 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -14,46 +14,180 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= +// +// This file implements a basic inlining algorithm that operates bottom up over +// the Strongly Connect Components(SCCs) of the CallGraph. This enables a more +// incremental propagation of inlining decisions from the leafs to the roots of +// the callgraph. +// +//===----------------------------------------------------------------------===// -#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Analysis/CallGraph.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Passes.h" -#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/SCCIterator.h" using namespace mlir; -// TODO(riverriddle) This pass should currently only be used for basic testing -// of inlining functionality. +//===----------------------------------------------------------------------===// +// CallGraph traversal +//===----------------------------------------------------------------------===// + +/// Run a given transformation over the SCCs of the callgraph in a bottom up +/// traversal. +static void runTransformOnCGSCCs( + const CallGraph &cg, + function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { + for (auto cgi = llvm::scc_begin(&cg); !cgi.isAtEnd(); ++cgi) + sccTransformer(*cgi); +} + namespace { -struct Inliner : public ModulePass<Inliner> { - void runOnModule() override { - auto module = getModule(); - - // Collect each of the direct function calls within the module. - SmallVector<CallOp, 16> callOps; - for (auto &f : module) - f.walk([&](CallOp callOp) { callOps.push_back(callOp); }); - - // Build the inliner interface. - InlinerInterface interface(&getContext()); - - // Try to inline each of the call operations. - for (auto &call : callOps) { - if (failed(inlineFunction( - interface, module.lookupSymbol<FuncOp>(call.getCallee()), call, - llvm::to_vector<8>(call.getArgOperands()), - llvm::to_vector<8>(call.getResults()), call.getLoc()))) +/// This struct represents a resolved call to a given callgraph node. Given that +/// the call does not actually contain a direct reference to the +/// Region(CallGraphNode) that it is dispatching to, we need to resolve them +/// explicitly. +struct ResolvedCall { + ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) + : call(call), targetNode(targetNode) {} + CallOpInterface call; + CallGraphNode *targetNode; +}; +} // end anonymous namespace + +/// Collect all of the callable operations within the given range of blocks. If +/// `traverseNestedCGNodes` is true, this will also collect call operations +/// inside of nested callgraph nodes. +static void collectCallOps(llvm::iterator_range<Region::iterator> blocks, + CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, + bool traverseNestedCGNodes) { + SmallVector<Block *, 8> worklist; + auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) { + for (Block &block : blocks) + worklist.push_back(&block); + }; + + addToWorklist(blocks); + while (!worklist.empty()) { + for (Operation &op : *worklist.pop_back_val()) { + if (auto call = dyn_cast<CallOpInterface>(op)) { + CallGraphNode *node = + cg.resolveCallable(call.getCallableForCallee(), &op); + if (!node->isExternal()) + calls.emplace_back(call, node); continue; + } - // If the inlining was successful then erase the call. - call.erase(); + // If this is not a call, traverse the nested regions. If + // `traverseNestedCGNodes` is false, then don't traverse nested call graph + // regions. + for (auto &nestedRegion : op.getRegions()) + if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) + addToWorklist(nestedRegion); } } +} + +//===----------------------------------------------------------------------===// +// Inliner +//===----------------------------------------------------------------------===// +namespace { +/// This class provides a specialization of the main inlining interface. +struct Inliner : public InlinerInterface { + Inliner(MLIRContext *context, CallGraph &cg) + : InlinerInterface(context), cg(cg) {} + + /// Process a set of blocks that have been inlined. This callback is invoked + /// *before* inlined terminator operations have been processed. + void processInlinedBlocks( + llvm::iterator_range<Region::iterator> inlinedBlocks) final { + collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); + } + + /// The current set of call instructions to consider for inlining. + SmallVector<ResolvedCall, 8> calls; + + /// The callgraph being operated on. + CallGraph &cg; +}; +} // namespace + +/// Returns true if the given call should be inlined. +static bool shouldInline(ResolvedCall &resolvedCall) { + // Don't allow inlining terminator calls. We currently don't support this + // case. + if (resolvedCall.call.getOperation()->isKnownTerminator()) + return false; + + // Don't allow inlining if the target is an ancestor of the call. This + // prevents inlining recursively. + if (resolvedCall.targetNode->getCallableRegion()->isAncestor( + resolvedCall.call.getParentRegion())) + return false; + + // Otherwise, inline. + return true; +} + +/// Attempt to inline calls within the given scc. +static void inlineCallsInSCC(Inliner &inliner, + ArrayRef<CallGraphNode *> currentSCC) { + CallGraph &cg = inliner.cg; + auto &calls = inliner.calls; + + // Collect all of the direct calls within the nodes of the current SCC. We + // don't traverse nested callgraph nodes, because they are handled separately + // likely within a different SCC. + for (auto *node : currentSCC) { + if (!node->isExternal()) + collectCallOps(*node->getCallableRegion(), cg, calls, + /*traverseNestedCGNodes=*/false); + } + if (calls.empty()) + return; + + // Try to inline each of the call operations. Don't cache the end iterator + // here as more calls may be added during inlining. + for (unsigned i = 0; i != calls.size(); ++i) { + ResolvedCall &it = calls[i]; + if (!shouldInline(it)) + continue; + + CallOpInterface call = it.call; + LogicalResult inlineResult = inlineRegion( + inliner, it.targetNode->getCallableRegion(), call, + llvm::to_vector<8>(call.getArgOperands()), + llvm::to_vector<8>(call.getOperation()->getResults()), call.getLoc()); + if (failed(inlineResult)) + continue; + + // If the inlining was successful, then erase the call. + call.erase(); + } + calls.clear(); +} + +//===----------------------------------------------------------------------===// +// InlinerPass +//===----------------------------------------------------------------------===// + +// TODO(riverriddle) This pass should currently only be used for basic testing +// of inlining functionality. +namespace { +struct InlinerPass : public OperationPass<InlinerPass> { + void runOnOperation() override { + CallGraph &cg = getAnalysis<CallGraph>(); + Inliner inliner(&getContext(), cg); + + // Run the inline transform in post-order over the SCCs in the callgraph. + runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { + inlineCallsInSCC(inliner, scc); + }); + } }; } // end anonymous namespace -static PassRegistration<Inliner> pass("inline", "Inline function calls"); +static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 901599ce023..6ca875b25ae 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -186,6 +186,9 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, if (!shouldCloneInlinedRegion) remapInlinedOperands(newBlocks, mapper); + // Process the newly inlined blocks. + interface.processInlinedBlocks(newBlocks); + // Handle the case where only a single block was inlined. if (std::next(newBlocks.begin()) == newBlocks.end()) { // Have the interface handle the terminator of this block. diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index 3cfb5eee5c7..9732992b013 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -72,3 +72,36 @@ func @no_inline_external() { call @func_external() : () -> () return } + +// Check that multiple levels of calls will be inlined. +func @multilevel_func_a() { + return +} +func @multilevel_func_b() { + call @multilevel_func_a() : () -> () + return +} + +// CHECK-LABEL: func @inline_multilevel +func @inline_multilevel() { + // CHECK-NOT: call + %fn = "test.functional_region_op"() ({ + call @multilevel_func_b() : () -> () + "test.return"() : () -> () + }) : () -> (() -> ()) + + call_indirect %fn() : () -> () + return +} + +// Check that recursive calls are not inlined. +// CHECK-LABEL: func @no_inline_recursive +func @no_inline_recursive() { + // CHECK: test.functional_region_op + // CHECK-NOT: test.functional_region_op + %fn = "test.functional_region_op"() ({ + call @no_inline_recursive() : () -> () + "test.return"() : () -> () + }) : () -> (() -> ()) + return +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index d91bb1a2f57..ca523d8a52f 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -49,6 +49,10 @@ struct TestInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { + // Inlining into test dialect regions is legal. + return true; + } bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { return true; |