summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-04 15:49:09 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-04 15:49:42 -0800
commit2c930f8d9daa04576ac49e14c19dc540ebf823fe (patch)
tree6dcb8515f827477ef49752a5d19818404988e53c /mlir/lib/IR
parentb3f7cf80a7dc7e9edd5b53827a942bada4a6aeb2 (diff)
downloadbcm5719-llvm-2c930f8d9daa04576ac49e14c19dc540ebf823fe.tar.gz
bcm5719-llvm-2c930f8d9daa04576ac49e14c19dc540ebf823fe.zip
Add emitOptional(Error|Warning|Remark) functions to simplify emission with an optional location.
In some situations a diagnostic may optionally be emitted by the presence of a location, e.g. attribute and type verification. These situations currently require extra 'if(loc) emitError(...); return failure()' wrappers that make verification clunky. These new overloads take an optional location and a list of arguments to the diagnostic, and return a LogicalResult. We take the arguments directly and return LogicalResult instead of returning InFlightDiagnostic because we cannot create a valid diagnostic with a null location. This creates an awkward situation where a user may try to treat the, potentially null, diagnostic as a valid one and encounter crashes when attaching notes/etc. Below is an example of how these methods simplify some existing usages: Before: if (loc) emitError(*loc, "this is my diagnostic with argument: ") << 5; return failure(); After: return emitOptionalError(loc, "this is my diagnostic with argument: ", 5); PiperOrigin-RevId: 283853599
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/Attributes.cpp43
-rw-r--r--mlir/lib/IR/StandardTypes.cpp82
-rw-r--r--mlir/lib/IR/Types.cpp14
3 files changed, 56 insertions, 83 deletions
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 5d7a4f08d1e..f2f3d41f980 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -214,35 +214,31 @@ double FloatAttr::getValueAsDouble(APFloat value) {
}
/// Verify construction invariants.
-static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
+static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
Type type) {
- if (!type.isa<FloatType>()) {
- if (loc)
- emitError(*loc, "expected floating point type");
- return failure();
- }
+ if (!type.isa<FloatType>())
+ return emitOptionalError(loc, "expected floating point type");
return success();
}
-LogicalResult FloatAttr::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
+LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *ctx,
+ Type type, double value) {
return verifyFloatTypeInvariants(loc, type);
}
-LogicalResult
-FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
- MLIRContext *ctx, Type type,
- const APFloat &value) {
+LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *ctx,
+ Type type,
+ const APFloat &value) {
// Verify that the type is correct.
if (failed(verifyFloatTypeInvariants(loc, type)))
return failure();
// Verify that the type semantics match that of the value.
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
- if (loc)
- emitError(*loc,
- "FloatAttr type doesn't match the type implied by its value");
- return failure();
+ return emitOptionalError(
+ loc, "FloatAttr type doesn't match the type implied by its value");
}
return success();
}
@@ -330,14 +326,13 @@ Identifier OpaqueAttr::getDialectNamespace() const {
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
/// Verify the construction of an opaque attribute.
-LogicalResult OpaqueAttr::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
- StringRef attrData, Type type) {
- if (!Dialect::isValidNamespace(dialect.strref())) {
- if (loc)
- emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
- return failure();
- }
+LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *context,
+ Identifier dialect,
+ StringRef attrData,
+ Type type) {
+ if (!Dialect::isValidNamespace(dialect.strref()))
+ return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
return success();
}
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 4347856de36..8a4b51f215a 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -61,13 +61,12 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
constexpr unsigned IntegerType::kMaxWidth;
/// Verify the construction of an integer type.
-LogicalResult IntegerType::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *context, unsigned width) {
+LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *context,
+ unsigned width) {
if (width > IntegerType::kMaxWidth) {
- if (loc)
- emitError(*loc) << "integer bitwidth is limited to "
- << IntegerType::kMaxWidth << " bits";
- return failure();
+ return emitOptionalError(loc, "integer bitwidth is limited to ",
+ IntegerType::kMaxWidth, " bits");
}
return success();
}
@@ -213,26 +212,21 @@ VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
StandardTypes::Vector, shape, elementType);
}
-LogicalResult VectorType::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
- Type elementType) {
- if (shape.empty()) {
- if (loc)
- emitError(*loc, "vector types must have at least one dimension");
- return failure();
- }
+LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *context,
+ ArrayRef<int64_t> shape,
+ Type elementType) {
+ if (shape.empty())
+ return emitOptionalError(loc,
+ "vector types must have at least one dimension");
- if (!isValidElementType(elementType)) {
- if (loc)
- emitError(*loc, "vector elements must be int or float type");
- return failure();
- }
+ if (!isValidElementType(elementType))
+ return emitOptionalError(loc, "vector elements must be int or float type");
+
+ if (any_of(shape, [](int64_t i) { return i <= 0; }))
+ return emitOptionalError(loc,
+ "vector types must have positive constant sizes");
- if (any_of(shape, [](int64_t i) { return i <= 0; })) {
- if (loc)
- emitError(*loc, "vector types must have positive constant sizes");
- return failure();
- }
return success();
}
@@ -247,11 +241,8 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
static inline LogicalResult checkTensorElementType(Optional<Location> location,
MLIRContext *context,
Type elementType) {
- if (!TensorType::isValidElementType(elementType)) {
- if (location)
- emitError(*location, "invalid tensor element type");
- return failure();
- }
+ if (!TensorType::isValidElementType(elementType))
+ return emitOptionalError(location, "invalid tensor element type");
return success();
}
@@ -273,14 +264,11 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
}
LogicalResult RankedTensorType::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
+ Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType) {
for (int64_t s : shape) {
- if (s < -1) {
- if (loc)
- emitError(*loc, "invalid tensor dimension size");
- return failure();
- }
+ if (s < -1)
+ return emitOptionalError(loc, "invalid tensor dimension size");
}
return checkTensorElementType(loc, context, elementType);
}
@@ -305,7 +293,7 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
}
LogicalResult UnrankedTensorType::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
+ Optional<Location> loc, MLIRContext *context, Type elementType) {
return checkTensorElementType(loc, context, elementType);
}
@@ -350,19 +338,14 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) {
- if (location)
- emitError(*location, "invalid memref element type");
- return nullptr;
- }
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+ return emitOptionalError(location, "invalid memref element type"),
+ MemRefType();
for (int64_t s : shape) {
// Negative sizes are not allowed except for `-1` that means dynamic size.
- if (s < -1) {
- if (location)
- emitError(*location, "invalid memref size");
- return {};
- }
+ if (s < -1)
+ return emitOptionalError(location, "invalid memref size"), MemRefType();
}
// Check that the structure of the composition is valid, i.e. that each
@@ -631,11 +614,8 @@ ComplexType ComplexType::getChecked(Type elementType, Location location) {
/// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
- if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
- if (loc)
- emitError(*loc, "invalid element type for complex");
- return failure();
- }
+ if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
+ return emitOptionalError(loc, "invalid element type for complex");
return success();
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index f1a6d8f11c9..23c80c96aad 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -80,13 +80,11 @@ Identifier OpaqueType::getDialectNamespace() const {
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
/// Verify the construction of an opaque type.
-LogicalResult OpaqueType::verifyConstructionInvariants(
- llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
- StringRef typeData) {
- if (!Dialect::isValidNamespace(dialect.strref())) {
- if (loc)
- emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
- return failure();
- }
+LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *context,
+ Identifier dialect,
+ StringRef typeData) {
+ if (!Dialect::isValidNamespace(dialect.strref()))
+ return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
return success();
}
OpenPOWER on IntegriCloud