summaryrefslogtreecommitdiffstats
path: root/mlir/test/lib/TestDialect
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2019-12-06 14:42:16 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-06 14:42:45 -0800
commit4add9edd7212f9e0b51552250cee606b4d93a043 (patch)
tree22201ca96b94a591d268231ecec5d61dab2e1b59 /mlir/test/lib/TestDialect
parente96150eb46d8c381f11a7333f0384aad0fc8d1b6 (diff)
downloadbcm5719-llvm-4add9edd7212f9e0b51552250cee606b4d93a043.tar.gz
bcm5719-llvm-4add9edd7212f9e0b51552250cee606b4d93a043.zip
Change inferReturnTypes to return LogicalResult and values
Previously the error case was using a sentinel in the error case which was bad. Also make the one `build` invoke the other `build` to reuse verification there. And follow up on suggestion to use formatv which I missed during previous review. PiperOrigin-RevId: 284265762
Diffstat (limited to 'mlir/test/lib/TestDialect')
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp16
-rw-r--r--mlir/test/lib/TestDialect/TestPatterns.cpp11
2 files changed, 15 insertions, 12 deletions
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 60a16d968dc..8b9f9a9874a 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -289,17 +289,17 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
return success();
}
-SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
+LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
llvm::Optional<Location> location, ArrayRef<Value *> operands,
- ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
+ ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions,
+ SmallVectorImpl<Type> &inferedReturnTypes) {
if (operands[0]->getType() != operands[1]->getType()) {
- if (location)
- mlir::emitError(*location)
- << "operand type mismatch " << operands[0]->getType() << " vs "
- << operands[1]->getType();
- return {nullptr};
+ return emitOptionalError(location, "operand type mismatch ",
+ operands[0]->getType(), " vs ",
+ operands[1]->getType());
}
- return {operands[0]->getType()};
+ inferedReturnTypes.assign({operands[0]->getType()});
+ return success();
}
// Static initialization for Test dialect registration.
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index 7b835c5e61d..06911d2b833 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -74,10 +74,13 @@ struct ReturnTypeOpMatch : public RewritePattern {
PatternRewriter &rewriter) const final {
if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
SmallVector<Value *, 4> values(op->getOperands());
- auto res = retTypeFn.inferReturnTypes(op->getLoc(), values,
- op->getAttrs(), op->getRegions());
- SmallVector<Type, 1> result_types(op->getResultTypes());
- if (!retTypeFn.isCompatibleReturnTypes(res, result_types))
+ SmallVector<Type, 2> inferedReturnTypes;
+ if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values,
+ op->getAttrs(), op->getRegions(),
+ inferedReturnTypes)))
+ return matchFailure();
+ SmallVector<Type, 1> resultTypes(op->getResultTypes());
+ if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
return op->emitOpError(
"inferred type incompatible with return type of operation"),
matchFailure();
OpenPOWER on IntegriCloud