diff options
| author | River Riddle <riverriddle@google.com> | 2019-10-03 23:10:25 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-10-03 23:10:51 -0700 |
| commit | 5830f71a45df33e24c864bea4c5de070be47b488 (patch) | |
| tree | a10efe1ce637e9995cf3a4033d1b17533e6d0069 /mlir/lib/Transforms/Utils | |
| parent | a20d96e436272b52d36f52c4a07c86ed285502e9 (diff) | |
| download | bcm5719-llvm-5830f71a45df33e24c864bea4c5de070be47b488.tar.gz bcm5719-llvm-5830f71a45df33e24c864bea4c5de070be47b488.zip | |
Add support for inlining calls with different arg/result types from the callable.
Some dialects have implicit conversions inherent in their modeling, meaning that a call may have a different type that the type that the callable expects. To support this, a hook is added to the dialect interface that allows for materializing conversion operations during inlining when there is a mismatch. A hook is also added to the callable interface to allow for introspecting the expected result types.
PiperOrigin-RevId: 272814379
Diffstat (limited to 'mlir/lib/Transforms/Utils')
| -rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 134 |
1 files changed, 102 insertions, 32 deletions
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 6ca875b25ae..fd08c53b0dc 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/MapVector.h" @@ -65,8 +66,6 @@ remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks, // InlinerInterface //===----------------------------------------------------------------------===// -InlinerInterface::~InlinerInterface() {} - bool InlinerInterface::isLegalToInline( Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { // Regions can always be inlined into functions. @@ -74,7 +73,7 @@ bool InlinerInterface::isLegalToInline( return true; auto *handler = getInterfaceFor(dest->getParentOp()); - return handler ? handler->isLegalToInline(src, dest, valueMapping) : false; + return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; } bool InlinerInterface::isLegalToInline( @@ -253,38 +252,109 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, inlineLoc, shouldCloneInlinedRegion); } -/// This function inlines a FuncOp into another. This function returns failure -/// if it is not possible to inline this FuncOp. If the function returned -/// failure, then no changes to the module have been made. -/// -/// Note that this only does one level of inlining. For example, if the -/// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now -/// exists in the instruction stream. Similarly this will inline a recursive -/// FuncOp by one level. -/// -LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee, - Operation *inlinePoint, - ArrayRef<Value *> callOperands, - ArrayRef<Value *> callResults, - Location inlineLoc) { - // We don't inline if the provided callee function is a declaration. - assert(callee && "expected valid function to inline"); - if (callee.isExternal()) - return failure(); +/// Utility function used to generate a cast operation from the given interface, +/// or return nullptr if a cast could not be generated. +static Value *materializeConversion(const DialectInlinerInterface *interface, + SmallVectorImpl<Operation *> &castOps, + OpBuilder &castBuilder, Value *arg, + Type type, Location conversionLoc) { + if (!interface) + return nullptr; + + // Check to see if the interface for the call can materialize a conversion. + Operation *castOp = interface->materializeCallConversion(castBuilder, arg, + type, conversionLoc); + if (!castOp) + return nullptr; + castOps.push_back(castOp); + + // Ensure that the generated cast is correct. + assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && + castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); + return castOp->getResult(0); +} - // Verify that the provided arguments match the function arguments. - if (callOperands.size() != callee.getNumArguments()) +/// This function inlines a given region, 'src', of a callable operation, +/// 'callable', into the location defined by the given call operation. This +/// function returns failure if inlining is not possible, success otherwise. On +/// failure, no changes are made to the module. 'shouldCloneInlinedRegion' +/// corresponds to whether the source region should be cloned into the 'call' or +/// spliced directly. +LogicalResult mlir::inlineCall(InlinerInterface &interface, + CallOpInterface call, + CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) return failure(); + auto *entryBlock = &src->front(); + ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); + + // Make sure that the number of arguments and results matchup between the call + // and the region. + SmallVector<Value *, 8> callOperands(call.getArgOperands()); + SmallVector<Value *, 8> callResults(call.getOperation()->getResults()); + if (callOperands.size() != entryBlock->getNumArguments() || + callResults.size() != callableResultTypes.size()) + return failure(); + + // A set of cast operations generated to matchup the signature of the region + // with the signature of the call. + SmallVector<Operation *, 4> castOps; + castOps.reserve(callOperands.size() + callResults.size()); - // Verify that the provided values to replace match the function results. - auto funcResultTypes = callee.getType().getResults(); - if (callResults.size() != funcResultTypes.size()) + // Functor used to cleanup generated state on failure. + auto cleanupState = [&] { + for (auto *op : castOps) { + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->erase(); + } return failure(); - for (unsigned i = 0, e = callResults.size(); i != e; ++i) - if (callResults[i]->getType() != funcResultTypes[i]) - return failure(); + }; - // Call into the main region inliner function. - return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands, - callResults, inlineLoc); + // Builder used for any conversion operations that need to be materialized. + OpBuilder castBuilder(call); + Location castLoc = call.getLoc(); + auto *callInterface = interface.getInterfaceFor(call.getDialect()); + + // Map the provided call operands to the arguments of the region. + BlockAndValueMapping mapper; + for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { + BlockArgument *regionArg = entryBlock->getArgument(i); + Value *operand = callOperands[i]; + + // If the call operand doesn't match the expected region argument, try to + // generate a cast. + Type regionArgType = regionArg->getType(); + if (operand->getType() != regionArgType) { + if (!(operand = materializeConversion(callInterface, castOps, castBuilder, + operand, regionArgType, castLoc))) + return cleanupState(); + } + mapper.map(regionArg, operand); + } + + // Ensure that the resultant values of the call, match the callable. + castBuilder.setInsertionPointAfter(call); + for (unsigned i = 0, e = callResults.size(); i != e; ++i) { + Value *callResult = callResults[i]; + if (callResult->getType() == callableResultTypes[i]) + continue; + + // Generate a conversion that will produce the original type, so that the IR + // is still valid after the original call gets replaced. + Value *castResult = + materializeConversion(callInterface, castOps, castBuilder, callResult, + callResult->getType(), castLoc); + if (!castResult) + return cleanupState(); + callResult->replaceAllUsesWith(castResult); + castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); + } + + // Attempt to inline the call. + if (failed(inlineRegion(interface, src, call, mapper, callResults, + call.getLoc(), shouldCloneInlinedRegion))) + return cleanupState(); + return success(); } |

