diff options
author | River Riddle <riverriddle@google.com> | 2020-01-09 14:41:49 -0800 |
---|---|---|
committer | River Riddle <riverriddle@google.com> | 2020-01-09 14:51:44 -0800 |
commit | 68c8b6c4cd117cc962155298f0e1d45056ecc001 (patch) | |
tree | 52eed3e8545894ac3e087798316da1eb47389f44 | |
parent | 58b3dec6c108eb9ae4af2cde5c831743d5605c79 (diff) | |
download | bcm5719-llvm-68c8b6c4cd117cc962155298f0e1d45056ecc001.tar.gz bcm5719-llvm-68c8b6c4cd117cc962155298f0e1d45056ecc001.zip |
[mlir] Use getDenseElementBitwidth instead of Type::getElementTypeBitWidth.
Summary: Some data values have a different storage width than the corresponding MLIR type, e.g. bfloat is currently stored as a double.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D72478
-rw-r--r-- | mlir/lib/IR/Attributes.cpp | 3 | ||||
-rw-r--r-- | mlir/unittests/IR/AttributeTest.cpp | 10 |
2 files changed, 12 insertions, 1 deletions
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 3a9c91f6f77..afcc83bb878 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -676,7 +676,8 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt) { // Make sure that the data element size is the same as the type element width. - if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth()) + if (getDenseElementBitwidth(type.getElementType()) != + static_cast<size_t>(dataEltSize * CHAR_BIT)) return false; // Check that the element type is valid. diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 5a1750e1123..066d069c02e 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -135,4 +135,14 @@ TEST(DenseSplatTest, FloatAttrSplat) { testSplat(floatTy, value); } + +TEST(DenseSplatTest, BF16Splat) { + MLIRContext context; + FloatType floatTy = FloatType::getBF16(&context); + // Note: We currently use double to represent bfloat16. + double value = 10.0; + + testSplat(floatTy, value); +} + } // end namespace |