summaryrefslogtreecommitdiffstats
path: root/mlir/unittests/IR/AttributeTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/unittests/IR/AttributeTest.cpp')
-rw-r--r--mlir/unittests/IR/AttributeTest.cpp138
1 files changed, 138 insertions, 0 deletions
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
new file mode 100644
index 00000000000..5a1750e1123
--- /dev/null
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -0,0 +1,138 @@
+//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+template <typename EltTy>
+static void testSplat(Type eltType, const EltTy &splatElt) {
+ VectorType shape = VectorType::get({2, 1}, eltType);
+
+ // Check that the generated splat is the same for 1 element and N elements.
+ DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
+ EXPECT_TRUE(splat.isSplat());
+
+ auto detectedSplat =
+ DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
+ EXPECT_EQ(detectedSplat, splat);
+}
+
+namespace {
+TEST(DenseSplatTest, BoolSplat) {
+ MLIRContext context;
+ IntegerType boolTy = IntegerType::get(1, &context);
+ VectorType shape = VectorType::get({2, 2}, boolTy);
+
+ // Check that splat is automatically detected for boolean values.
+ /// True.
+ DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
+ EXPECT_TRUE(trueSplat.isSplat());
+ /// False.
+ DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
+ EXPECT_TRUE(falseSplat.isSplat());
+ EXPECT_NE(falseSplat, trueSplat);
+
+ /// Detect and handle splat within 8 elements (bool values are bit-packed).
+ /// True.
+ auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
+ EXPECT_EQ(detectedSplat, trueSplat);
+ /// False.
+ detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
+ EXPECT_EQ(detectedSplat, falseSplat);
+}
+
+TEST(DenseSplatTest, LargeBoolSplat) {
+ constexpr int64_t boolCount = 56;
+
+ MLIRContext context;
+ IntegerType boolTy = IntegerType::get(1, &context);
+ VectorType shape = VectorType::get({boolCount}, boolTy);
+
+ // Check that splat is automatically detected for boolean values.
+ /// True.
+ DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
+ DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
+ EXPECT_TRUE(trueSplat.isSplat());
+ EXPECT_TRUE(falseSplat.isSplat());
+
+ /// Detect that the large boolean arrays are properly splatted.
+ /// True.
+ SmallVector<bool, 64> trueValues(boolCount, true);
+ auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
+ EXPECT_EQ(detectedSplat, trueSplat);
+ /// False.
+ SmallVector<bool, 64> falseValues(boolCount, false);
+ detectedSplat = DenseElementsAttr::get(shape, falseValues);
+ EXPECT_EQ(detectedSplat, falseSplat);
+}
+
+TEST(DenseSplatTest, BoolNonSplat) {
+ MLIRContext context;
+ IntegerType boolTy = IntegerType::get(1, &context);
+ VectorType shape = VectorType::get({6}, boolTy);
+
+ // Check that we properly handle non-splat values.
+ DenseElementsAttr nonSplat =
+ DenseElementsAttr::get(shape, {false, false, true, false, false, true});
+ EXPECT_FALSE(nonSplat.isSplat());
+}
+
+TEST(DenseSplatTest, OddIntSplat) {
+ // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
+ MLIRContext context;
+ constexpr size_t intWidth = 19;
+ IntegerType intTy = IntegerType::get(intWidth, &context);
+ APInt value(intWidth, 10);
+
+ testSplat(intTy, value);
+}
+
+TEST(DenseSplatTest, Int32Splat) {
+ MLIRContext context;
+ IntegerType intTy = IntegerType::get(32, &context);
+ int value = 64;
+
+ testSplat(intTy, value);
+}
+
+TEST(DenseSplatTest, IntAttrSplat) {
+ MLIRContext context;
+ IntegerType intTy = IntegerType::get(85, &context);
+ Attribute value = IntegerAttr::get(intTy, 109);
+
+ testSplat(intTy, value);
+}
+
+TEST(DenseSplatTest, F32Splat) {
+ MLIRContext context;
+ FloatType floatTy = FloatType::getF32(&context);
+ float value = 10.0;
+
+ testSplat(floatTy, value);
+}
+
+TEST(DenseSplatTest, F64Splat) {
+ MLIRContext context;
+ FloatType floatTy = FloatType::getF64(&context);
+ double value = 10.0;
+
+ testSplat(floatTy, APFloat(value));
+}
+
+TEST(DenseSplatTest, FloatAttrSplat) {
+ MLIRContext context;
+ FloatType floatTy = FloatType::getBF16(&context);
+ Attribute value = FloatAttr::get(floatTy, 10.0);
+
+ testSplat(floatTy, value);
+}
+} // end namespace
OpenPOWER on IntegriCloud