diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/InliningUtils.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 356 |
1 files changed, 356 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp new file mode 100644 index 00000000000..1ac286c67fb --- /dev/null +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -0,0 +1,356 @@ +//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements miscellaneous inlining utilities. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/InliningUtils.h" + +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "inlining" + +using namespace mlir; + +/// Remap locations from the inlined blocks with CallSiteLoc locations with the +/// provided caller location. +static void +remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, + Location callerLoc) { + DenseMap<Location, Location> mappedLocations; + auto remapOpLoc = [&](Operation *op) { + auto it = mappedLocations.find(op->getLoc()); + if (it == mappedLocations.end()) { + auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); + it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; + } + op->setLoc(it->second); + }; + for (auto &block : inlinedBlocks) + block.walk(remapOpLoc); +} + +static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, + BlockAndValueMapping &mapper) { + auto remapOperands = [&](Operation *op) { + for (auto &operand : op->getOpOperands()) + if (auto mappedOp = mapper.lookupOrNull(operand.get())) + operand.set(mappedOp); + }; + for (auto &block : inlinedBlocks) + block.walk(remapOperands); +} + +//===----------------------------------------------------------------------===// +// InlinerInterface +//===----------------------------------------------------------------------===// + +bool InlinerInterface::isLegalToInline( + Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { + // Regions can always be inlined into functions. + if (isa<FuncOp>(dest->getParentOp())) + return true; + + auto *handler = getInterfaceFor(dest->getParentOp()); + return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; +} + +bool InlinerInterface::isLegalToInline( + Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { + auto *handler = getInterfaceFor(op); + return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; +} + +bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { + auto *handler = getInterfaceFor(op); + return handler ? handler->shouldAnalyzeRecursively(op) : true; +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { + auto *handler = getInterfaceFor(op); + assert(handler && "expected valid dialect handler"); + handler->handleTerminator(op, newDest); +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void InlinerInterface::handleTerminator(Operation *op, + ArrayRef<Value> valuesToRepl) const { + auto *handler = getInterfaceFor(op); + assert(handler && "expected valid dialect handler"); + handler->handleTerminator(op, valuesToRepl); +} + +/// Utility to check that all of the operations within 'src' can be inlined. +static bool isLegalToInline(InlinerInterface &interface, Region *src, + Region *insertRegion, + BlockAndValueMapping &valueMapping) { + for (auto &block : *src) { + for (auto &op : block) { + // Check this operation. + if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) { + LLVM_DEBUG({ + llvm::dbgs() << "* Illegal to inline because of op: "; + op.dump(); + }); + return false; + } + // Check any nested regions. + if (interface.shouldAnalyzeRecursively(&op) && + llvm::any_of(op.getRegions(), [&](Region ®ion) { + return !isLegalToInline(interface, ®ion, insertRegion, + valueMapping); + })) + return false; + } + } + return true; +} + +//===----------------------------------------------------------------------===// +// Inline Methods +//===----------------------------------------------------------------------===// + +LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, + Operation *inlinePoint, + BlockAndValueMapping &mapper, + ArrayRef<Value> resultsToReplace, + Optional<Location> inlineLoc, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) + return failure(); + + // Check that all of the region arguments have been mapped. + auto *srcEntryBlock = &src->front(); + if (llvm::any_of(srcEntryBlock->getArguments(), + [&](BlockArgument arg) { return !mapper.contains(arg); })) + return failure(); + + // The insertion point must be within a block. + Block *insertBlock = inlinePoint->getBlock(); + if (!insertBlock) + return failure(); + Region *insertRegion = insertBlock->getParent(); + + // Check that the operations within the source region are valid to inline. + if (!interface.isLegalToInline(insertRegion, src, mapper) || + !isLegalToInline(interface, src, insertRegion, mapper)) + return failure(); + + // Split the insertion block. + Block *postInsertBlock = + insertBlock->splitBlock(++inlinePoint->getIterator()); + + // Check to see if the region is being cloned, or moved inline. In either + // case, move the new blocks after the 'insertBlock' to improve IR + // readability. + if (shouldCloneInlinedRegion) + src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); + else + insertRegion->getBlocks().splice(postInsertBlock->getIterator(), + src->getBlocks(), src->begin(), + src->end()); + + // Get the range of newly inserted blocks. + auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), + postInsertBlock->getIterator()); + Block *firstNewBlock = &*newBlocks.begin(); + + // Remap the locations of the inlined operations if a valid source location + // was provided. + if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) + remapInlinedLocations(newBlocks, *inlineLoc); + + // If the blocks were moved in-place, make sure to remap any necessary + // operands. + if (!shouldCloneInlinedRegion) + remapInlinedOperands(newBlocks, mapper); + + // Process the newly inlined blocks. + interface.processInlinedBlocks(newBlocks); + + // Handle the case where only a single block was inlined. + if (std::next(newBlocks.begin()) == newBlocks.end()) { + // Have the interface handle the terminator of this block. + auto *firstBlockTerminator = firstNewBlock->getTerminator(); + interface.handleTerminator(firstBlockTerminator, resultsToReplace); + firstBlockTerminator->erase(); + + // Merge the post insert block into the cloned entry block. + firstNewBlock->getOperations().splice(firstNewBlock->end(), + postInsertBlock->getOperations()); + postInsertBlock->erase(); + } else { + // Otherwise, there were multiple blocks inlined. Add arguments to the post + // insertion block to represent the results to replace. + for (Value resultToRepl : resultsToReplace) { + resultToRepl->replaceAllUsesWith( + postInsertBlock->addArgument(resultToRepl->getType())); + } + + /// Handle the terminators for each of the new blocks. + for (auto &newBlock : newBlocks) + interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); + } + + // Splice the instructions of the inlined entry block into the insert block. + insertBlock->getOperations().splice(insertBlock->end(), + firstNewBlock->getOperations()); + firstNewBlock->erase(); + return success(); +} + +/// This function is an overload of the above 'inlineRegion' that allows for +/// providing the set of operands ('inlinedOperands') that should be used +/// in-favor of the region arguments when inlining. +LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, + Operation *inlinePoint, + ArrayRef<Value> inlinedOperands, + ArrayRef<Value> resultsToReplace, + Optional<Location> inlineLoc, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) + return failure(); + + auto *entryBlock = &src->front(); + if (inlinedOperands.size() != entryBlock->getNumArguments()) + return failure(); + + // Map the provided call operands to the arguments of the region. + BlockAndValueMapping mapper; + for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { + // Verify that the types of the provided values match the function argument + // types. + BlockArgument regionArg = entryBlock->getArgument(i); + if (inlinedOperands[i]->getType() != regionArg->getType()) + return failure(); + mapper.map(regionArg, inlinedOperands[i]); + } + + // Call into the main region inliner function. + return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, + inlineLoc, shouldCloneInlinedRegion); +} + +/// Utility function used to generate a cast operation from the given interface, +/// or return nullptr if a cast could not be generated. +static Value materializeConversion(const DialectInlinerInterface *interface, + SmallVectorImpl<Operation *> &castOps, + OpBuilder &castBuilder, Value arg, Type type, + Location conversionLoc) { + if (!interface) + return nullptr; + + // Check to see if the interface for the call can materialize a conversion. + Operation *castOp = interface->materializeCallConversion(castBuilder, arg, + type, conversionLoc); + if (!castOp) + return nullptr; + castOps.push_back(castOp); + + // Ensure that the generated cast is correct. + assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && + castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); + return castOp->getResult(0); +} + +/// This function inlines a given region, 'src', of a callable operation, +/// 'callable', into the location defined by the given call operation. This +/// function returns failure if inlining is not possible, success otherwise. On +/// failure, no changes are made to the module. 'shouldCloneInlinedRegion' +/// corresponds to whether the source region should be cloned into the 'call' or +/// spliced directly. +LogicalResult mlir::inlineCall(InlinerInterface &interface, + CallOpInterface call, + CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) + return failure(); + auto *entryBlock = &src->front(); + ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); + + // Make sure that the number of arguments and results matchup between the call + // and the region. + SmallVector<Value, 8> callOperands(call.getArgOperands()); + SmallVector<Value, 8> callResults(call.getOperation()->getResults()); + if (callOperands.size() != entryBlock->getNumArguments() || + callResults.size() != callableResultTypes.size()) + return failure(); + + // A set of cast operations generated to matchup the signature of the region + // with the signature of the call. + SmallVector<Operation *, 4> castOps; + castOps.reserve(callOperands.size() + callResults.size()); + + // Functor used to cleanup generated state on failure. + auto cleanupState = [&] { + for (auto *op : castOps) { + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->erase(); + } + return failure(); + }; + + // Builder used for any conversion operations that need to be materialized. + OpBuilder castBuilder(call); + Location castLoc = call.getLoc(); + auto *callInterface = interface.getInterfaceFor(call.getDialect()); + + // Map the provided call operands to the arguments of the region. + BlockAndValueMapping mapper; + for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { + BlockArgument regionArg = entryBlock->getArgument(i); + Value operand = callOperands[i]; + + // If the call operand doesn't match the expected region argument, try to + // generate a cast. + Type regionArgType = regionArg->getType(); + if (operand->getType() != regionArgType) { + if (!(operand = materializeConversion(callInterface, castOps, castBuilder, + operand, regionArgType, castLoc))) + return cleanupState(); + } + mapper.map(regionArg, operand); + } + + // Ensure that the resultant values of the call, match the callable. + castBuilder.setInsertionPointAfter(call); + for (unsigned i = 0, e = callResults.size(); i != e; ++i) { + Value callResult = callResults[i]; + if (callResult->getType() == callableResultTypes[i]) + continue; + + // Generate a conversion that will produce the original type, so that the IR + // is still valid after the original call gets replaced. + Value castResult = + materializeConversion(callInterface, castOps, castBuilder, callResult, + callResult->getType(), castLoc); + if (!castResult) + return cleanupState(); + callResult->replaceAllUsesWith(castResult); + castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); + } + + // Attempt to inline the call. + if (failed(inlineRegion(interface, src, call, mapper, callResults, + call.getLoc(), shouldCloneInlinedRegion))) + return cleanupState(); + return success(); +} |