//===- AttributeTest.cpp - Attribute unit tests ---------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/Attributes.h" #include "mlir/IR/StandardTypes.h" #include "gtest/gtest.h" using namespace mlir; using namespace mlir::detail; template static void testSplat(Type eltType, const EltTy &splatElt) { VectorType shape = VectorType::get({2, 1}, eltType); // Check that the generated splat is the same for 1 element and N elements. DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); EXPECT_TRUE(splat.isSplat()); auto detectedSplat = DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); EXPECT_EQ(detectedSplat, splat); } namespace { TEST(DenseSplatTest, BoolSplat) { MLIRContext context; IntegerType boolTy = IntegerType::get(1, &context); VectorType shape = VectorType::get({2, 2}, boolTy); // Check that splat is automatically detected for boolean values. /// True. DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); EXPECT_TRUE(trueSplat.isSplat()); /// False. DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); EXPECT_TRUE(falseSplat.isSplat()); EXPECT_NE(falseSplat, trueSplat); /// Detect and handle splat within 8 elements (bool values are bit-packed). /// True. auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); EXPECT_EQ(detectedSplat, trueSplat); /// False. detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); EXPECT_EQ(detectedSplat, falseSplat); } TEST(DenseSplatTest, LargeBoolSplat) { constexpr int64_t boolCount = 56; MLIRContext context; IntegerType boolTy = IntegerType::get(1, &context); VectorType shape = VectorType::get({boolCount}, boolTy); // Check that splat is automatically detected for boolean values. /// True. DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); EXPECT_TRUE(trueSplat.isSplat()); EXPECT_TRUE(falseSplat.isSplat()); /// Detect that the large boolean arrays are properly splatted. /// True. SmallVector trueValues(boolCount, true); auto detectedSplat = DenseElementsAttr::get(shape, trueValues); EXPECT_EQ(detectedSplat, trueSplat); /// False. SmallVector falseValues(boolCount, false); detectedSplat = DenseElementsAttr::get(shape, falseValues); EXPECT_EQ(detectedSplat, falseSplat); } TEST(DenseSplatTest, BoolNonSplat) { MLIRContext context; IntegerType boolTy = IntegerType::get(1, &context); VectorType shape = VectorType::get({6}, boolTy); // Check that we properly handle non-splat values. DenseElementsAttr nonSplat = DenseElementsAttr::get(shape, {false, false, true, false, false, true}); EXPECT_FALSE(nonSplat.isSplat()); } TEST(DenseSplatTest, OddIntSplat) { // Test detecting a splat with an odd(non 8-bit) integer bitwidth. MLIRContext context; constexpr size_t intWidth = 19; IntegerType intTy = IntegerType::get(intWidth, &context); APInt value(intWidth, 10); testSplat(intTy, value); } TEST(DenseSplatTest, Int32Splat) { MLIRContext context; IntegerType intTy = IntegerType::get(32, &context); int value = 64; testSplat(intTy, value); } TEST(DenseSplatTest, IntAttrSplat) { MLIRContext context; IntegerType intTy = IntegerType::get(85, &context); Attribute value = IntegerAttr::get(intTy, 109); testSplat(intTy, value); } TEST(DenseSplatTest, F32Splat) { MLIRContext context; FloatType floatTy = FloatType::getF32(&context); float value = 10.0; testSplat(floatTy, value); } TEST(DenseSplatTest, F64Splat) { MLIRContext context; FloatType floatTy = FloatType::getF64(&context); double value = 10.0; testSplat(floatTy, APFloat(value)); } TEST(DenseSplatTest, FloatAttrSplat) { MLIRContext context; FloatType floatTy = FloatType::getBF16(&context); Attribute value = FloatAttr::get(floatTy, 10.0); testSplat(floatTy, value); } TEST(DenseSplatTest, BF16Splat) { MLIRContext context; FloatType floatTy = FloatType::getBF16(&context); // Note: We currently use double to represent bfloat16. double value = 10.0; testSplat(floatTy, value); } } // end namespace