summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2020-01-09 14:41:49 -0800
committerRiver Riddle <riverriddle@google.com>2020-01-09 14:51:44 -0800
commit68c8b6c4cd117cc962155298f0e1d45056ecc001 (patch)
tree52eed3e8545894ac3e087798316da1eb47389f44
parent58b3dec6c108eb9ae4af2cde5c831743d5605c79 (diff)
downloadbcm5719-llvm-68c8b6c4cd117cc962155298f0e1d45056ecc001.tar.gz
bcm5719-llvm-68c8b6c4cd117cc962155298f0e1d45056ecc001.zip
[mlir] Use getDenseElementBitwidth instead of Type::getElementTypeBitWidth.
Summary: Some data values have a different storage width than the corresponding MLIR type, e.g. bfloat is currently stored as a double. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D72478
-rw-r--r--mlir/lib/IR/Attributes.cpp3
-rw-r--r--mlir/unittests/IR/AttributeTest.cpp10
2 files changed, 12 insertions, 1 deletions
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 3a9c91f6f77..afcc83bb878 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -676,7 +676,8 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
bool isInt) {
// Make sure that the data element size is the same as the type element width.
- if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
+ if (getDenseElementBitwidth(type.getElementType()) !=
+ static_cast<size_t>(dataEltSize * CHAR_BIT))
return false;
// Check that the element type is valid.
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 5a1750e1123..066d069c02e 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -135,4 +135,14 @@ TEST(DenseSplatTest, FloatAttrSplat) {
testSplat(floatTy, value);
}
+
+TEST(DenseSplatTest, BF16Splat) {
+ MLIRContext context;
+ FloatType floatTy = FloatType::getBF16(&context);
+ // Note: We currently use double to represent bfloat16.
+ double value = 10.0;
+
+ testSplat(floatTy, value);
+}
+
} // end namespace
OpenPOWER on IntegriCloud