summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Utils
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-10-03 23:10:25 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-03 23:10:51 -0700
commit5830f71a45df33e24c864bea4c5de070be47b488 (patch)
treea10efe1ce637e9995cf3a4033d1b17533e6d0069 /mlir/lib/Transforms/Utils
parenta20d96e436272b52d36f52c4a07c86ed285502e9 (diff)
downloadbcm5719-llvm-5830f71a45df33e24c864bea4c5de070be47b488.tar.gz
bcm5719-llvm-5830f71a45df33e24c864bea4c5de070be47b488.zip
Add support for inlining calls with different arg/result types from the callable.
Some dialects have implicit conversions inherent in their modeling, meaning that a call may have a different type that the type that the callable expects. To support this, a hook is added to the dialect interface that allows for materializing conversion operations during inlining when there is a mismatch. A hook is also added to the callable interface to allow for introspecting the expected result types. PiperOrigin-RevId: 272814379
Diffstat (limited to 'mlir/lib/Transforms/Utils')
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp134
1 files changed, 102 insertions, 32 deletions
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 6ca875b25ae..fd08c53b0dc 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
@@ -65,8 +66,6 @@ remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks,
// InlinerInterface
//===----------------------------------------------------------------------===//
-InlinerInterface::~InlinerInterface() {}
-
bool InlinerInterface::isLegalToInline(
Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
// Regions can always be inlined into functions.
@@ -74,7 +73,7 @@ bool InlinerInterface::isLegalToInline(
return true;
auto *handler = getInterfaceFor(dest->getParentOp());
- return handler ? handler->isLegalToInline(src, dest, valueMapping) : false;
+ return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
}
bool InlinerInterface::isLegalToInline(
@@ -253,38 +252,109 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
inlineLoc, shouldCloneInlinedRegion);
}
-/// This function inlines a FuncOp into another. This function returns failure
-/// if it is not possible to inline this FuncOp. If the function returned
-/// failure, then no changes to the module have been made.
-///
-/// Note that this only does one level of inlining. For example, if the
-/// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now
-/// exists in the instruction stream. Similarly this will inline a recursive
-/// FuncOp by one level.
-///
-LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee,
- Operation *inlinePoint,
- ArrayRef<Value *> callOperands,
- ArrayRef<Value *> callResults,
- Location inlineLoc) {
- // We don't inline if the provided callee function is a declaration.
- assert(callee && "expected valid function to inline");
- if (callee.isExternal())
- return failure();
+/// Utility function used to generate a cast operation from the given interface,
+/// or return nullptr if a cast could not be generated.
+static Value *materializeConversion(const DialectInlinerInterface *interface,
+ SmallVectorImpl<Operation *> &castOps,
+ OpBuilder &castBuilder, Value *arg,
+ Type type, Location conversionLoc) {
+ if (!interface)
+ return nullptr;
+
+ // Check to see if the interface for the call can materialize a conversion.
+ Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
+ type, conversionLoc);
+ if (!castOp)
+ return nullptr;
+ castOps.push_back(castOp);
+
+ // Ensure that the generated cast is correct.
+ assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
+ castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
+ return castOp->getResult(0);
+}
- // Verify that the provided arguments match the function arguments.
- if (callOperands.size() != callee.getNumArguments())
+/// This function inlines a given region, 'src', of a callable operation,
+/// 'callable', into the location defined by the given call operation. This
+/// function returns failure if inlining is not possible, success otherwise. On
+/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
+/// corresponds to whether the source region should be cloned into the 'call' or
+/// spliced directly.
+LogicalResult mlir::inlineCall(InlinerInterface &interface,
+ CallOpInterface call,
+ CallableOpInterface callable, Region *src,
+ bool shouldCloneInlinedRegion) {
+ // We expect the region to have at least one block.
+ if (src->empty())
return failure();
+ auto *entryBlock = &src->front();
+ ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
+
+ // Make sure that the number of arguments and results matchup between the call
+ // and the region.
+ SmallVector<Value *, 8> callOperands(call.getArgOperands());
+ SmallVector<Value *, 8> callResults(call.getOperation()->getResults());
+ if (callOperands.size() != entryBlock->getNumArguments() ||
+ callResults.size() != callableResultTypes.size())
+ return failure();
+
+ // A set of cast operations generated to matchup the signature of the region
+ // with the signature of the call.
+ SmallVector<Operation *, 4> castOps;
+ castOps.reserve(callOperands.size() + callResults.size());
- // Verify that the provided values to replace match the function results.
- auto funcResultTypes = callee.getType().getResults();
- if (callResults.size() != funcResultTypes.size())
+ // Functor used to cleanup generated state on failure.
+ auto cleanupState = [&] {
+ for (auto *op : castOps) {
+ op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
+ op->erase();
+ }
return failure();
- for (unsigned i = 0, e = callResults.size(); i != e; ++i)
- if (callResults[i]->getType() != funcResultTypes[i])
- return failure();
+ };
- // Call into the main region inliner function.
- return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands,
- callResults, inlineLoc);
+ // Builder used for any conversion operations that need to be materialized.
+ OpBuilder castBuilder(call);
+ Location castLoc = call.getLoc();
+ auto *callInterface = interface.getInterfaceFor(call.getDialect());
+
+ // Map the provided call operands to the arguments of the region.
+ BlockAndValueMapping mapper;
+ for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
+ BlockArgument *regionArg = entryBlock->getArgument(i);
+ Value *operand = callOperands[i];
+
+ // If the call operand doesn't match the expected region argument, try to
+ // generate a cast.
+ Type regionArgType = regionArg->getType();
+ if (operand->getType() != regionArgType) {
+ if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
+ operand, regionArgType, castLoc)))
+ return cleanupState();
+ }
+ mapper.map(regionArg, operand);
+ }
+
+ // Ensure that the resultant values of the call, match the callable.
+ castBuilder.setInsertionPointAfter(call);
+ for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
+ Value *callResult = callResults[i];
+ if (callResult->getType() == callableResultTypes[i])
+ continue;
+
+ // Generate a conversion that will produce the original type, so that the IR
+ // is still valid after the original call gets replaced.
+ Value *castResult =
+ materializeConversion(callInterface, castOps, castBuilder, callResult,
+ callResult->getType(), castLoc);
+ if (!castResult)
+ return cleanupState();
+ callResult->replaceAllUsesWith(castResult);
+ castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
+ }
+
+ // Attempt to inline the call.
+ if (failed(inlineRegion(interface, src, call, mapper, callResults,
+ call.getLoc(), shouldCloneInlinedRegion)))
+ return cleanupState();
+ return success();
}
OpenPOWER on IntegriCloud