summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-10-03 23:10:25 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-10-03 23:10:51 -0700
commit5830f71a45df33e24c864bea4c5de070be47b488 (patch)
treea10efe1ce637e9995cf3a4033d1b17533e6d0069 /mlir
parenta20d96e436272b52d36f52c4a07c86ed285502e9 (diff)
downloadbcm5719-llvm-5830f71a45df33e24c864bea4c5de070be47b488.tar.gz
bcm5719-llvm-5830f71a45df33e24c864bea4c5de070be47b488.zip
Add support for inlining calls with different arg/result types from the callable.
Some dialects have implicit conversions inherent in their modeling, meaning that a call may have a different type that the type that the callable expects. To support this, a hook is added to the dialect interface that allows for materializing conversion operations during inlining when there is a mismatch. A hook is also added to the callable interface to allow for introspecting the expected result types. PiperOrigin-RevId: 272814379
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/CallInterfaces.td8
-rw-r--r--mlir/include/mlir/IR/Function.h7
-rw-r--r--mlir/include/mlir/Transforms/InliningUtils.h52
-rw-r--r--mlir/lib/Transforms/Inliner.cpp8
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp134
-rw-r--r--mlir/test/Transforms/inlining.mlir36
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp17
-rw-r--r--mlir/test/lib/TestDialect/TestOps.td23
8 files changed, 228 insertions, 57 deletions
diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td
index fca7773ce63..3ed802f07f4 100644
--- a/mlir/include/mlir/Analysis/CallInterfaces.td
+++ b/mlir/include/mlir/Analysis/CallInterfaces.td
@@ -80,11 +80,17 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
"Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable)
>,
InterfaceMethod<[{
- Returns all of the callable regions of this operation
+ Returns all of the callable regions of this operation.
}],
"void", "getCallableRegions",
(ins "SmallVectorImpl<Region *> &":$callables)
>,
+ InterfaceMethod<[{
+ Returns the results types that the given callable region produces when
+ executed.
+ }],
+ "ArrayRef<Type>", "getCallableResults", (ins "Region *":$callable)
+ >,
];
}
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 6bf6e65c38c..95920b38c14 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -128,6 +128,13 @@ public:
callables.push_back(&getBody());
}
+ /// Returns the results types that the given callable region produces when
+ /// executed.
+ ArrayRef<Type> getCallableResults(Region *region) {
+ assert(!isExternal() && region == &getBody() && "invalid callable");
+ return getType().getResults();
+ }
+
private:
// This trait needs access to `getNumFuncArguments` and `verifyType` hooks
// defined below.
diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index 7fe67e78127..fd12624886d 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -30,7 +30,10 @@ namespace mlir {
class Block;
class BlockAndValueMapping;
+class CallableOpInterface;
+class CallOpInterface;
class FuncOp;
+class OpBuilder;
class Operation;
class Region;
class Value;
@@ -106,6 +109,27 @@ public:
llvm_unreachable(
"must implement handleTerminator in the case of one inlined block");
}
+
+ /// Attempt to materialize a conversion for a type mismatch between a call
+ /// from this dialect, and a callable region. This method should generate an
+ /// operation that takes 'input' as the only operand, and produces a single
+ /// result of 'resultType'. If a conversion can not be generated, nullptr
+ /// should be returned. For example, this hook may be invoked in the following
+ /// scenarios:
+ /// func @foo(i32) -> i32 { ... }
+ ///
+ /// // Mismatched input operand
+ /// ... = foo.call @foo(%input : i16) -> i32
+ ///
+ /// // Mismatched result type.
+ /// ... = foo.call @foo(%input : i32) -> i16
+ ///
+ /// NOTE: This hook may be invoked before the 'isLegal' checks above.
+ virtual Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+ Type resultType,
+ Location conversionLoc) const {
+ return nullptr;
+ }
};
/// This interface provides the hooks into the inlining interface.
@@ -115,7 +139,6 @@ class InlinerInterface
: public DialectInterfaceCollection<DialectInlinerInterface> {
public:
using Base::Base;
- virtual ~InlinerInterface();
/// Process a set of blocks that have been inlined. This callback is invoked
/// *before* inlined terminator operations have been processed.
@@ -178,24 +201,15 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
llvm::Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
-/// This function inlines a FuncOp into another. This function returns failure
-/// if it is not possible to inline this FuncOp. If the function returned
-/// failure, then no changes to the module have been made.
-///
-/// Note that this only does one level of inlining. For example, if the
-/// instruction 'call B' is inlined into function 'A', and function 'B' also
-/// calls 'C', then the call to 'C' now exists inside the body of 'A'. Similarly
-/// this will inline a recursive FuncOp by one level.
-///
-/// 'callOperands' must correspond, 1-1, with the arguments to the provided
-/// FuncOp. 'callResults' must correspond, 1-1, with the results of the
-/// provided FuncOp. These results will be replaced by the operands of any
-/// return operations that are inlined. 'inlineLoc' should refer to the location
-/// that the FuncOp is being inlined into.
-LogicalResult inlineFunction(InlinerInterface &interface, FuncOp callee,
- Operation *inlinePoint,
- ArrayRef<Value *> callOperands,
- ArrayRef<Value *> callResults, Location inlineLoc);
+/// This function inlines a given region, 'src', of a callable operation,
+/// 'callable', into the location defined by the given call operation. This
+/// function returns failure if inlining is not possible, success otherwise. On
+/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
+/// corresponds to whether the source region should be cloned into the 'call' or
+/// spliced directly.
+LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
+ CallableOpInterface callable, Region *src,
+ bool shouldCloneInlinedRegion = true);
} // end namespace mlir
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index afb2dccc241..c5defa5b3a6 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -157,10 +157,10 @@ static void inlineCallsInSCC(Inliner &inliner,
continue;
CallOpInterface call = it.call;
- LogicalResult inlineResult = inlineRegion(
- inliner, it.targetNode->getCallableRegion(), call,
- llvm::to_vector<8>(call.getArgOperands()),
- llvm::to_vector<8>(call.getOperation()->getResults()), call.getLoc());
+ Region *targetRegion = it.targetNode->getCallableRegion();
+ LogicalResult inlineResult = inlineCall(
+ inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
+ targetRegion);
if (failed(inlineResult))
continue;
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 6ca875b25ae..fd08c53b0dc 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
@@ -65,8 +66,6 @@ remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks,
// InlinerInterface
//===----------------------------------------------------------------------===//
-InlinerInterface::~InlinerInterface() {}
-
bool InlinerInterface::isLegalToInline(
Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
// Regions can always be inlined into functions.
@@ -74,7 +73,7 @@ bool InlinerInterface::isLegalToInline(
return true;
auto *handler = getInterfaceFor(dest->getParentOp());
- return handler ? handler->isLegalToInline(src, dest, valueMapping) : false;
+ return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
}
bool InlinerInterface::isLegalToInline(
@@ -253,38 +252,109 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
inlineLoc, shouldCloneInlinedRegion);
}
-/// This function inlines a FuncOp into another. This function returns failure
-/// if it is not possible to inline this FuncOp. If the function returned
-/// failure, then no changes to the module have been made.
-///
-/// Note that this only does one level of inlining. For example, if the
-/// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now
-/// exists in the instruction stream. Similarly this will inline a recursive
-/// FuncOp by one level.
-///
-LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee,
- Operation *inlinePoint,
- ArrayRef<Value *> callOperands,
- ArrayRef<Value *> callResults,
- Location inlineLoc) {
- // We don't inline if the provided callee function is a declaration.
- assert(callee && "expected valid function to inline");
- if (callee.isExternal())
- return failure();
+/// Utility function used to generate a cast operation from the given interface,
+/// or return nullptr if a cast could not be generated.
+static Value *materializeConversion(const DialectInlinerInterface *interface,
+ SmallVectorImpl<Operation *> &castOps,
+ OpBuilder &castBuilder, Value *arg,
+ Type type, Location conversionLoc) {
+ if (!interface)
+ return nullptr;
+
+ // Check to see if the interface for the call can materialize a conversion.
+ Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
+ type, conversionLoc);
+ if (!castOp)
+ return nullptr;
+ castOps.push_back(castOp);
+
+ // Ensure that the generated cast is correct.
+ assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
+ castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
+ return castOp->getResult(0);
+}
- // Verify that the provided arguments match the function arguments.
- if (callOperands.size() != callee.getNumArguments())
+/// This function inlines a given region, 'src', of a callable operation,
+/// 'callable', into the location defined by the given call operation. This
+/// function returns failure if inlining is not possible, success otherwise. On
+/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
+/// corresponds to whether the source region should be cloned into the 'call' or
+/// spliced directly.
+LogicalResult mlir::inlineCall(InlinerInterface &interface,
+ CallOpInterface call,
+ CallableOpInterface callable, Region *src,
+ bool shouldCloneInlinedRegion) {
+ // We expect the region to have at least one block.
+ if (src->empty())
return failure();
+ auto *entryBlock = &src->front();
+ ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
+
+ // Make sure that the number of arguments and results matchup between the call
+ // and the region.
+ SmallVector<Value *, 8> callOperands(call.getArgOperands());
+ SmallVector<Value *, 8> callResults(call.getOperation()->getResults());
+ if (callOperands.size() != entryBlock->getNumArguments() ||
+ callResults.size() != callableResultTypes.size())
+ return failure();
+
+ // A set of cast operations generated to matchup the signature of the region
+ // with the signature of the call.
+ SmallVector<Operation *, 4> castOps;
+ castOps.reserve(callOperands.size() + callResults.size());
- // Verify that the provided values to replace match the function results.
- auto funcResultTypes = callee.getType().getResults();
- if (callResults.size() != funcResultTypes.size())
+ // Functor used to cleanup generated state on failure.
+ auto cleanupState = [&] {
+ for (auto *op : castOps) {
+ op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
+ op->erase();
+ }
return failure();
- for (unsigned i = 0, e = callResults.size(); i != e; ++i)
- if (callResults[i]->getType() != funcResultTypes[i])
- return failure();
+ };
- // Call into the main region inliner function.
- return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands,
- callResults, inlineLoc);
+ // Builder used for any conversion operations that need to be materialized.
+ OpBuilder castBuilder(call);
+ Location castLoc = call.getLoc();
+ auto *callInterface = interface.getInterfaceFor(call.getDialect());
+
+ // Map the provided call operands to the arguments of the region.
+ BlockAndValueMapping mapper;
+ for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
+ BlockArgument *regionArg = entryBlock->getArgument(i);
+ Value *operand = callOperands[i];
+
+ // If the call operand doesn't match the expected region argument, try to
+ // generate a cast.
+ Type regionArgType = regionArg->getType();
+ if (operand->getType() != regionArgType) {
+ if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
+ operand, regionArgType, castLoc)))
+ return cleanupState();
+ }
+ mapper.map(regionArg, operand);
+ }
+
+ // Ensure that the resultant values of the call, match the callable.
+ castBuilder.setInsertionPointAfter(call);
+ for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
+ Value *callResult = callResults[i];
+ if (callResult->getType() == callableResultTypes[i])
+ continue;
+
+ // Generate a conversion that will produce the original type, so that the IR
+ // is still valid after the original call gets replaced.
+ Value *castResult =
+ materializeConversion(callInterface, castOps, castBuilder, callResult,
+ callResult->getType(), castLoc);
+ if (!castResult)
+ return cleanupState();
+ callResult->replaceAllUsesWith(castResult);
+ castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
+ }
+
+ // Attempt to inline the call.
+ if (failed(inlineRegion(interface, src, call, mapper, callResults,
+ call.getLoc(), shouldCloneInlinedRegion)))
+ return cleanupState();
+ return success();
}
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index 9732992b013..4d855d02b8f 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -105,3 +105,39 @@ func @no_inline_recursive() {
}) : () -> (() -> ())
return
}
+
+// Check that we can convert types for inputs and results as necessary.
+func @convert_callee_fn(%arg : i32) -> i32 {
+ return %arg : i32
+}
+func @convert_callee_fn_multi_arg(%a : i32, %b : i32) -> () {
+ return
+}
+func @convert_callee_fn_multi_res() -> (i32, i32) {
+ %res = constant 0 : i32
+ return %res, %res : i32, i32
+}
+
+// CHECK-LABEL: func @inline_convert_call
+func @inline_convert_call() -> i16 {
+ // CHECK: %[[INPUT:.*]] = constant
+ %test_input = constant 0 : i16
+
+ // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[INPUT]]) : (i16) -> i32
+ // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CAST_INPUT]]) : (i32) -> i16
+ // CHECK-NEXT: return %[[CAST_RESULT]]
+ %res = "test.conversion_call_op"(%test_input) { callee=@convert_callee_fn } : (i16) -> (i16)
+ return %res : i16
+}
+
+// CHECK-LABEL: func @no_inline_convert_call
+func @no_inline_convert_call() {
+ // CHECK: "test.conversion_call_op"
+ %test_input_i16 = constant 0 : i16
+ %test_input_i64 = constant 0 : i64
+ "test.conversion_call_op"(%test_input_i16, %test_input_i64) { callee=@convert_callee_fn_multi_arg } : (i16, i64) -> ()
+
+ // CHECK: "test.conversion_call_op"
+ %res_2:2 = "test.conversion_call_op"() { callee=@convert_callee_fn_multi_res } : () -> (i16, i64)
+ return
+}
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index ca523d8a52f..78a75f0a3d9 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -58,7 +58,7 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return true;
}
- bool shouldAnalyzeRecursively(Operation *op) const override {
+ bool shouldAnalyzeRecursively(Operation *op) const final {
// Analyze recursively if this is not a functional region operation, it
// froms a separate functional scope.
return !isa<FunctionalRegionOp>(op);
@@ -82,6 +82,21 @@ struct TestInlinerInterface : public DialectInlinerInterface {
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
+
+ /// Attempt to materialize a conversion for a type mismatch between a call
+ /// from this dialect, and a callable region. This method should generate an
+ /// operation that takes 'input' as the only operand, and produces a single
+ /// result of 'resultType'. If a conversion can not be generated, nullptr
+ /// should be returned.
+ Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+ Type resultType,
+ Location conversionLoc) const final {
+ // Only allow conversion for i16/i32 types.
+ if (!(resultType.isInteger(16) || resultType.isInteger(32)) ||
+ !(input->getType().isInteger(16) || input->getType().isInteger(32)))
+ return nullptr;
+ return builder.create<TestCastOp>(conversionLoc, resultType, input);
+ }
};
} // end anonymous namespace
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 944ce79a182..41e44f69c2f 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -194,6 +194,26 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
}
+//===----------------------------------------------------------------------===//
+// Test Call Interfaces
+//===----------------------------------------------------------------------===//
+
+def ConversionCallOp : TEST_Op<"conversion_call_op",
+ [CallOpInterface]> {
+ let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
+ let results = (outs Variadic<AnyType>);
+
+ let extraClassDeclaration = [{
+ /// Get the argument operands to the called function.
+ operand_range getArgOperands() { return inputs(); }
+
+ /// Return the callee of this operation.
+ CallInterfaceCallable getCallableForCallee() {
+ return getAttrOfType<SymbolRefAttr>("callee");
+ }
+ }];
+}
+
def FunctionalRegionOp : TEST_Op<"functional_region_op",
[CallableOpInterface]> {
let regions = (region AnyRegion:$body);
@@ -204,6 +224,9 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
void getCallableRegions(SmallVectorImpl<Region *> &callables) {
callables.push_back(&body());
}
+ ArrayRef<Type> getCallableResults(Region *) {
+ return getType().cast<FunctionType>().getResults();
+ }
}];
}
OpenPOWER on IntegriCloud