diff options
| author | Jacques Pienaar <jpienaar@google.com> | 2019-12-06 14:42:16 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-06 14:42:45 -0800 |
| commit | 4add9edd7212f9e0b51552250cee606b4d93a043 (patch) | |
| tree | 22201ca96b94a591d268231ecec5d61dab2e1b59 /mlir/test/lib/TestDialect | |
| parent | e96150eb46d8c381f11a7333f0384aad0fc8d1b6 (diff) | |
| download | bcm5719-llvm-4add9edd7212f9e0b51552250cee606b4d93a043.tar.gz bcm5719-llvm-4add9edd7212f9e0b51552250cee606b4d93a043.zip | |
Change inferReturnTypes to return LogicalResult and values
Previously the error case was using a sentinel in the error case which was bad. Also make the one `build` invoke the other `build` to reuse verification there.
And follow up on suggestion to use formatv which I missed during previous review.
PiperOrigin-RevId: 284265762
Diffstat (limited to 'mlir/test/lib/TestDialect')
| -rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.cpp | 16 | ||||
| -rw-r--r-- | mlir/test/lib/TestDialect/TestPatterns.cpp | 11 |
2 files changed, 15 insertions, 12 deletions
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 60a16d968dc..8b9f9a9874a 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -289,17 +289,17 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold( return success(); } -SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( +LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( llvm::Optional<Location> location, ArrayRef<Value *> operands, - ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) { + ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions, + SmallVectorImpl<Type> &inferedReturnTypes) { if (operands[0]->getType() != operands[1]->getType()) { - if (location) - mlir::emitError(*location) - << "operand type mismatch " << operands[0]->getType() << " vs " - << operands[1]->getType(); - return {nullptr}; + return emitOptionalError(location, "operand type mismatch ", + operands[0]->getType(), " vs ", + operands[1]->getType()); } - return {operands[0]->getType()}; + inferedReturnTypes.assign({operands[0]->getType()}); + return success(); } // Static initialization for Test dialect registration. diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 7b835c5e61d..06911d2b833 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -74,10 +74,13 @@ struct ReturnTypeOpMatch : public RewritePattern { PatternRewriter &rewriter) const final { if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) { SmallVector<Value *, 4> values(op->getOperands()); - auto res = retTypeFn.inferReturnTypes(op->getLoc(), values, - op->getAttrs(), op->getRegions()); - SmallVector<Type, 1> result_types(op->getResultTypes()); - if (!retTypeFn.isCompatibleReturnTypes(res, result_types)) + SmallVector<Type, 2> inferedReturnTypes; + if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values, + op->getAttrs(), op->getRegions(), + inferedReturnTypes))) + return matchFailure(); + SmallVector<Type, 1> resultTypes(op->getResultTypes()); + if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) return op->emitOpError( "inferred type incompatible with return type of operation"), matchFailure(); |

