diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-04 15:49:09 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-04 15:49:42 -0800 |
| commit | 2c930f8d9daa04576ac49e14c19dc540ebf823fe (patch) | |
| tree | 6dcb8515f827477ef49752a5d19818404988e53c /mlir/lib/IR | |
| parent | b3f7cf80a7dc7e9edd5b53827a942bada4a6aeb2 (diff) | |
| download | bcm5719-llvm-2c930f8d9daa04576ac49e14c19dc540ebf823fe.tar.gz bcm5719-llvm-2c930f8d9daa04576ac49e14c19dc540ebf823fe.zip | |
Add emitOptional(Error|Warning|Remark) functions to simplify emission with an optional location.
In some situations a diagnostic may optionally be emitted by the presence of a location, e.g. attribute and type verification. These situations currently require extra 'if(loc) emitError(...); return failure()' wrappers that make verification clunky. These new overloads take an optional location and a list of arguments to the diagnostic, and return a LogicalResult. We take the arguments directly and return LogicalResult instead of returning InFlightDiagnostic because we cannot create a valid diagnostic with a null location. This creates an awkward situation where a user may try to treat the, potentially null, diagnostic as a valid one and encounter crashes when attaching notes/etc. Below is an example of how these methods simplify some existing usages:
Before:
if (loc)
emitError(*loc, "this is my diagnostic with argument: ") << 5;
return failure();
After:
return emitOptionalError(loc, "this is my diagnostic with argument: ", 5);
PiperOrigin-RevId: 283853599
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/Attributes.cpp | 43 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardTypes.cpp | 82 | ||||
| -rw-r--r-- | mlir/lib/IR/Types.cpp | 14 |
3 files changed, 56 insertions, 83 deletions
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 5d7a4f08d1e..f2f3d41f980 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -214,35 +214,31 @@ double FloatAttr::getValueAsDouble(APFloat value) { } /// Verify construction invariants. -static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc, +static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc, Type type) { - if (!type.isa<FloatType>()) { - if (loc) - emitError(*loc, "expected floating point type"); - return failure(); - } + if (!type.isa<FloatType>()) + return emitOptionalError(loc, "expected floating point type"); return success(); } -LogicalResult FloatAttr::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) { +LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *ctx, + Type type, double value) { return verifyFloatTypeInvariants(loc, type); } -LogicalResult -FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc, - MLIRContext *ctx, Type type, - const APFloat &value) { +LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *ctx, + Type type, + const APFloat &value) { // Verify that the type is correct. if (failed(verifyFloatTypeInvariants(loc, type))) return failure(); // Verify that the type semantics match that of the value. if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { - if (loc) - emitError(*loc, - "FloatAttr type doesn't match the type implied by its value"); - return failure(); + return emitOptionalError( + loc, "FloatAttr type doesn't match the type implied by its value"); } return success(); } @@ -330,14 +326,13 @@ Identifier OpaqueAttr::getDialectNamespace() const { StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } /// Verify the construction of an opaque attribute. -LogicalResult OpaqueAttr::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect, - StringRef attrData, Type type) { - if (!Dialect::isValidNamespace(dialect.strref())) { - if (loc) - emitError(*loc) << "invalid dialect namespace '" << dialect << "'"; - return failure(); - } +LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *context, + Identifier dialect, + StringRef attrData, + Type type) { + if (!Dialect::isValidNamespace(dialect.strref())) + return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); return success(); } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 4347856de36..8a4b51f215a 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -61,13 +61,12 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); } constexpr unsigned IntegerType::kMaxWidth; /// Verify the construction of an integer type. -LogicalResult IntegerType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, unsigned width) { +LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *context, + unsigned width) { if (width > IntegerType::kMaxWidth) { - if (loc) - emitError(*loc) << "integer bitwidth is limited to " - << IntegerType::kMaxWidth << " bits"; - return failure(); + return emitOptionalError(loc, "integer bitwidth is limited to ", + IntegerType::kMaxWidth, " bits"); } return success(); } @@ -213,26 +212,21 @@ VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, StandardTypes::Vector, shape, elementType); } -LogicalResult VectorType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape, - Type elementType) { - if (shape.empty()) { - if (loc) - emitError(*loc, "vector types must have at least one dimension"); - return failure(); - } +LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *context, + ArrayRef<int64_t> shape, + Type elementType) { + if (shape.empty()) + return emitOptionalError(loc, + "vector types must have at least one dimension"); - if (!isValidElementType(elementType)) { - if (loc) - emitError(*loc, "vector elements must be int or float type"); - return failure(); - } + if (!isValidElementType(elementType)) + return emitOptionalError(loc, "vector elements must be int or float type"); + + if (any_of(shape, [](int64_t i) { return i <= 0; })) + return emitOptionalError(loc, + "vector types must have positive constant sizes"); - if (any_of(shape, [](int64_t i) { return i <= 0; })) { - if (loc) - emitError(*loc, "vector types must have positive constant sizes"); - return failure(); - } return success(); } @@ -247,11 +241,8 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } static inline LogicalResult checkTensorElementType(Optional<Location> location, MLIRContext *context, Type elementType) { - if (!TensorType::isValidElementType(elementType)) { - if (location) - emitError(*location, "invalid tensor element type"); - return failure(); - } + if (!TensorType::isValidElementType(elementType)) + return emitOptionalError(location, "invalid tensor element type"); return success(); } @@ -273,14 +264,11 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, } LogicalResult RankedTensorType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape, + Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape, Type elementType) { for (int64_t s : shape) { - if (s < -1) { - if (loc) - emitError(*loc, "invalid tensor dimension size"); - return failure(); - } + if (s < -1) + return emitOptionalError(loc, "invalid tensor dimension size"); } return checkTensorElementType(loc, context, elementType); } @@ -305,7 +293,7 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, } LogicalResult UnrankedTensorType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, Type elementType) { + Optional<Location> loc, MLIRContext *context, Type elementType) { return checkTensorElementType(loc, context, elementType); } @@ -350,19 +338,14 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, auto *context = elementType.getContext(); // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) { - if (location) - emitError(*location, "invalid memref element type"); - return nullptr; - } + if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) + return emitOptionalError(location, "invalid memref element type"), + MemRefType(); for (int64_t s : shape) { // Negative sizes are not allowed except for `-1` that means dynamic size. - if (s < -1) { - if (location) - emitError(*location, "invalid memref size"); - return {}; - } + if (s < -1) + return emitOptionalError(location, "invalid memref size"), MemRefType(); } // Check that the structure of the composition is valid, i.e. that each @@ -631,11 +614,8 @@ ComplexType ComplexType::getChecked(Type elementType, Location location) { /// Verify the construction of an integer type. LogicalResult ComplexType::verifyConstructionInvariants( llvm::Optional<Location> loc, MLIRContext *context, Type elementType) { - if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) { - if (loc) - emitError(*loc, "invalid element type for complex"); - return failure(); - } + if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) + return emitOptionalError(loc, "invalid element type for complex"); return success(); } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index f1a6d8f11c9..23c80c96aad 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -80,13 +80,11 @@ Identifier OpaqueType::getDialectNamespace() const { StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } /// Verify the construction of an opaque type. -LogicalResult OpaqueType::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect, - StringRef typeData) { - if (!Dialect::isValidNamespace(dialect.strref())) { - if (loc) - emitError(*loc) << "invalid dialect namespace '" << dialect << "'"; - return failure(); - } +LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc, + MLIRContext *context, + Identifier dialect, + StringRef typeData) { + if (!Dialect::isValidNamespace(dialect.strref())) + return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); return success(); } |

