summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/Attributes.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/Attributes.cpp')
-rw-r--r--mlir/lib/IR/Attributes.cpp1101
1 files changed, 1101 insertions, 0 deletions
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
new file mode 100644
index 00000000000..3a9c91f6f77
--- /dev/null
+++ b/mlir/lib/IR/Attributes.cpp
@@ -0,0 +1,1101 @@
+//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Attributes.h"
+#include "AttributeDetail.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/Twine.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// AttributeStorage
+//===----------------------------------------------------------------------===//
+
+AttributeStorage::AttributeStorage(Type type)
+ : type(type.getAsOpaquePointer()) {}
+AttributeStorage::AttributeStorage() : type(nullptr) {}
+
+Type AttributeStorage::getType() const {
+ return Type::getFromOpaquePointer(type);
+}
+void AttributeStorage::setType(Type newType) {
+ type = newType.getAsOpaquePointer();
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute
+//===----------------------------------------------------------------------===//
+
+/// Return the type of this attribute.
+Type Attribute::getType() const { return impl->getType(); }
+
+/// Return the context this attribute belongs to.
+MLIRContext *Attribute::getContext() const { return getType().getContext(); }
+
+/// Get the dialect this attribute is registered to.
+Dialect &Attribute::getDialect() const { return impl->getDialect(); }
+
+//===----------------------------------------------------------------------===//
+// AffineMapAttr
+//===----------------------------------------------------------------------===//
+
+AffineMapAttr AffineMapAttr::get(AffineMap value) {
+ return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
+}
+
+AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
+ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+ return Base::get(context, StandardAttributes::Array, value);
+}
+
+ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// BoolAttr
+//===----------------------------------------------------------------------===//
+
+bool BoolAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// DictionaryAttr
+//===----------------------------------------------------------------------===//
+
+/// Perform a three-way comparison between the names of the specified
+/// NamedAttributes.
+static int compareNamedAttributes(const NamedAttribute *lhs,
+ const NamedAttribute *rhs) {
+ return lhs->first.strref().compare(rhs->first.strref());
+}
+
+DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
+ MLIRContext *context) {
+ assert(llvm::all_of(value,
+ [](const NamedAttribute &attr) { return attr.second; }) &&
+ "value cannot have null entries");
+
+ // We need to sort the element list to canonicalize it, but we also don't want
+ // to do a ton of work in the super common case where the element list is
+ // already sorted.
+ SmallVector<NamedAttribute, 8> storage;
+ switch (value.size()) {
+ case 0:
+ break;
+ case 1:
+ // A single element is already sorted.
+ break;
+ case 2:
+ assert(value[0].first != value[1].first &&
+ "DictionaryAttr element names must be unique");
+
+ // Don't invoke a general sort for two element case.
+ if (value[0].first.strref() > value[1].first.strref()) {
+ storage.push_back(value[1]);
+ storage.push_back(value[0]);
+ value = storage;
+ }
+ break;
+ default:
+ // Check to see they are sorted already.
+ bool isSorted = true;
+ for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
+ if (value[i].first.strref() > value[i + 1].first.strref()) {
+ isSorted = false;
+ break;
+ }
+ }
+ // If not, do a general sort.
+ if (!isSorted) {
+ storage.append(value.begin(), value.end());
+ llvm::array_pod_sort(storage.begin(), storage.end(),
+ compareNamedAttributes);
+ value = storage;
+ }
+
+ // Ensure that the attribute elements are unique.
+ assert(std::adjacent_find(value.begin(), value.end(),
+ [](NamedAttribute l, NamedAttribute r) {
+ return l.first == r.first;
+ }) == value.end() &&
+ "DictionaryAttr element names must be unique");
+ }
+
+ return Base::get(context, StandardAttributes::Dictionary, value);
+}
+
+ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
+ return getImpl()->getElements();
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute DictionaryAttr::get(StringRef name) const {
+ ArrayRef<NamedAttribute> values = getValue();
+ auto compare = [](NamedAttribute attr, StringRef name) {
+ return attr.first.strref() < name;
+ };
+ auto it = llvm::lower_bound(values, name, compare);
+ return it != values.end() && it->first.is(name) ? it->second : Attribute();
+}
+Attribute DictionaryAttr::get(Identifier name) const {
+ for (auto elt : getValue())
+ if (elt.first == name)
+ return elt.second;
+ return nullptr;
+}
+
+DictionaryAttr::iterator DictionaryAttr::begin() const {
+ return getValue().begin();
+}
+DictionaryAttr::iterator DictionaryAttr::end() const {
+ return getValue().end();
+}
+size_t DictionaryAttr::size() const { return getValue().size(); }
+
+//===----------------------------------------------------------------------===//
+// FloatAttr
+//===----------------------------------------------------------------------===//
+
+FloatAttr FloatAttr::get(Type type, double value) {
+ return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
+ return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
+ type, value);
+}
+
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+ return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
+ return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
+ type, value);
+}
+
+APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
+
+double FloatAttr::getValueAsDouble() const {
+ return getValueAsDouble(getValue());
+}
+double FloatAttr::getValueAsDouble(APFloat value) {
+ if (&value.getSemantics() != &APFloat::IEEEdouble()) {
+ bool losesInfo = false;
+ value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ }
+ return value.convertToDouble();
+}
+
+/// Verify construction invariants.
+static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
+ Type type) {
+ if (!type.isa<FloatType>())
+ return emitOptionalError(loc, "expected floating point type");
+ return success();
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *ctx,
+ Type type, double value) {
+ return verifyFloatTypeInvariants(loc, type);
+}
+
+LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *ctx,
+ Type type,
+ const APFloat &value) {
+ // Verify that the type is correct.
+ if (failed(verifyFloatTypeInvariants(loc, type)))
+ return failure();
+
+ // Verify that the type semantics match that of the value.
+ if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
+ return emitOptionalError(
+ loc, "FloatAttr type doesn't match the type implied by its value");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
+ .cast<FlatSymbolRefAttr>();
+}
+
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences,
+ MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+}
+
+StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
+
+StringRef SymbolRefAttr::getLeafReference() const {
+ ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
+ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+}
+
+ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
+ return getImpl()->getNestedRefs();
+}
+
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
+IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
+ return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
+}
+
+IntegerAttr IntegerAttr::get(Type type, int64_t value) {
+ // This uses 64 bit APInts by default for index type.
+ if (type.isIndex())
+ return get(type, APInt(64, value));
+
+ auto intType = type.cast<IntegerType>();
+ return get(type, APInt(intType.getWidth(), value));
+}
+
+APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
+
+int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
+
+//===----------------------------------------------------------------------===//
+// IntegerSetAttr
+//===----------------------------------------------------------------------===//
+
+IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
+ return Base::get(value.getConstraint(0).getContext(),
+ StandardAttributes::IntegerSet, value);
+}
+
+IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
+OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
+ MLIRContext *context) {
+ return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
+ type);
+}
+
+OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
+ Type type, Location location) {
+ return Base::getChecked(location, type.getContext(),
+ StandardAttributes::Opaque, dialect, attrData, type);
+}
+
+/// Returns the dialect namespace of the opaque attribute.
+Identifier OpaqueAttr::getDialectNamespace() const {
+ return getImpl()->dialectNamespace;
+}
+
+/// Returns the raw attribute data of the opaque attribute.
+StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
+
+/// Verify the construction of an opaque attribute.
+LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
+ MLIRContext *context,
+ Identifier dialect,
+ StringRef attrData,
+ Type type) {
+ if (!Dialect::isValidNamespace(dialect.strref()))
+ return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
+StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+ return get(bytes, NoneType::get(context));
+}
+
+/// Get an instance of a StringAttr with the given string and Type.
+StringAttr StringAttr::get(StringRef bytes, Type type) {
+ return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
+}
+
+StringRef StringAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// TypeAttr
+//===----------------------------------------------------------------------===//
+
+TypeAttr TypeAttr::get(Type value) {
+ return Base::get(value.getContext(), StandardAttributes::Type, value);
+}
+
+Type TypeAttr::getValue() const { return getImpl()->value; }
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType ElementsAttr::getType() const {
+ return Attribute::getType().cast<ShapedType>();
+}
+
+/// Returns the number of elements held by this attribute.
+int64_t ElementsAttr::getNumElements() const {
+ return getType().getNumElements();
+}
+
+/// Return the value at the given index. If index does not refer to a valid
+/// element, then a null attribute is returned.
+Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+ switch (getKind()) {
+ case StandardAttributes::DenseElements:
+ return cast<DenseElementsAttr>().getValue(index);
+ case StandardAttributes::OpaqueElements:
+ return cast<OpaqueElementsAttr>().getValue(index);
+ case StandardAttributes::SparseElements:
+ return cast<SparseElementsAttr>().getValue(index);
+ default:
+ llvm_unreachable("unknown ElementsAttr kind");
+ }
+}
+
+/// Return if the given 'index' refers to a valid element in this attribute.
+bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
+ auto type = getType();
+
+ // Verify that the rank of the indices matches the held type.
+ auto rank = type.getRank();
+ if (rank != static_cast<int64_t>(index.size()))
+ return false;
+
+ // Verify that all of the indices are within the shape dimensions.
+ auto shape = type.getShape();
+ return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
+ return static_cast<int64_t>(index[i]) < shape[i];
+ });
+}
+
+ElementsAttr
+ElementsAttr::mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const {
+ switch (getKind()) {
+ case StandardAttributes::DenseElements:
+ return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+ default:
+ llvm_unreachable("unsupported ElementsAttr subtype");
+ }
+}
+
+ElementsAttr
+ElementsAttr::mapValues(Type newElementType,
+ function_ref<APInt(const APFloat &)> mapping) const {
+ switch (getKind()) {
+ case StandardAttributes::DenseElements:
+ return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+ default:
+ llvm_unreachable("unsupported ElementsAttr subtype");
+ }
+}
+
+/// Returns the 1 dimensional flattened row-major index from the given
+/// multi-dimensional index.
+uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
+ assert(isValidIndex(index) && "expected valid multi-dimensional index");
+ auto type = getType();
+
+ // Reduce the provided multidimensional index into a flattended 1D row-major
+ // index.
+ auto rank = type.getRank();
+ auto shape = type.getShape();
+ uint64_t valueIndex = 0;
+ uint64_t dimMultiplier = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ valueIndex += index[i] * dimMultiplier;
+ dimMultiplier *= shape[i];
+ }
+ return valueIndex;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseElementAttr Utilities
+//===----------------------------------------------------------------------===//
+
+static size_t getDenseElementBitwidth(Type eltType) {
+ // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+ // with double semantics.
+ return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
+}
+
+/// Get the bitwidth of a dense element type within the buffer.
+/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
+static size_t getDenseElementStorageWidth(size_t origWidth) {
+ return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+}
+
+/// Set a bit to a specific value.
+static void setBit(char *rawData, size_t bitPos, bool value) {
+ if (value)
+ rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
+ else
+ rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
+}
+
+/// Return the value of the specified bit.
+static bool getBit(const char *rawData, size_t bitPos) {
+ return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
+}
+
+/// Writes value to the bit position `bitPos` in array `rawData`.
+static void writeBits(char *rawData, size_t bitPos, APInt value) {
+ size_t bitWidth = value.getBitWidth();
+
+ // If the bitwidth is 1 we just toggle the specific bit.
+ if (bitWidth == 1)
+ return setBit(rawData, bitPos, value.isOneValue());
+
+ // Otherwise, the bit position is guaranteed to be byte aligned.
+ assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+ std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
+ llvm::divideCeil(bitWidth, CHAR_BIT),
+ rawData + (bitPos / CHAR_BIT));
+}
+
+/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
+/// `rawData`.
+static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
+ // Handle a boolean bit position.
+ if (bitWidth == 1)
+ return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
+
+ // Otherwise, the bit position must be 8-bit aligned.
+ assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+ APInt result(bitWidth, 0);
+ std::copy_n(
+ rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
+ const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
+ return result;
+}
+
+/// Returns if 'values' corresponds to a splat, i.e. one element, or has the
+/// same element count as 'type'.
+template <typename Values>
+static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
+ return (values.size() == 1) ||
+ (type.getNumElements() == static_cast<int64_t>(values.size()));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseElementAttr Iterators
+//===----------------------------------------------------------------------===//
+
+/// Constructs a new iterator.
+DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
+ DenseElementsAttr attr, size_t index)
+ : indexed_accessor_iterator<AttributeElementIterator, const void *,
+ Attribute, Attribute, Attribute>(
+ attr.getAsOpaquePointer(), index) {}
+
+/// Accesses the Attribute value at this iterator position.
+Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
+ auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
+ Type eltTy = owner.getType().getElementType();
+ if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
+ if (intEltTy.getWidth() == 1)
+ return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
+ owner.getContext());
+ return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
+ }
+ if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
+ IntElementIterator intIt(owner, index);
+ FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
+ return FloatAttr::get(eltTy, *floatIt);
+ }
+ llvm_unreachable("unexpected element type");
+}
+
+/// Constructs a new iterator.
+DenseElementsAttr::BoolElementIterator::BoolElementIterator(
+ DenseElementsAttr attr, size_t dataIndex)
+ : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
+ attr.getRawData().data(), attr.isSplat(), dataIndex) {}
+
+/// Accesses the bool value at this iterator position.
+bool DenseElementsAttr::BoolElementIterator::operator*() const {
+ return getBit(getData(), getDataIndex());
+}
+
+/// Constructs a new iterator.
+DenseElementsAttr::IntElementIterator::IntElementIterator(
+ DenseElementsAttr attr, size_t dataIndex)
+ : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
+ attr.getRawData().data(), attr.isSplat(), dataIndex),
+ bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
+
+/// Accesses the raw APInt value at this iterator position.
+APInt DenseElementsAttr::IntElementIterator::operator*() const {
+ return readBits(getData(),
+ getDataIndex() * getDenseElementStorageWidth(bitWidth),
+ bitWidth);
+}
+
+DenseElementsAttr::FloatElementIterator::FloatElementIterator(
+ const llvm::fltSemantics &smt, IntElementIterator it)
+ : llvm::mapped_iterator<IntElementIterator,
+ std::function<APFloat(const APInt &)>>(
+ it, [&](const APInt &val) { return APFloat(smt, val); }) {}
+
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<Attribute> values) {
+ assert(type.getElementType().isIntOrFloat() &&
+ "expected int or float element type");
+ assert(hasSameElementsOrSplat(type, values));
+
+ auto eltType = type.getElementType();
+ size_t bitWidth = getDenseElementBitwidth(eltType);
+ size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+ // Compress the attribute values into a character buffer.
+ SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+ values.size());
+ APInt intVal;
+ for (unsigned i = 0, e = values.size(); i < e; ++i) {
+ assert(eltType == values[i].getType() &&
+ "expected attribute value to have element type");
+
+ switch (eltType.getKind()) {
+ case StandardTypes::BF16:
+ case StandardTypes::F16:
+ case StandardTypes::F32:
+ case StandardTypes::F64:
+ intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
+ break;
+ case StandardTypes::Integer:
+ intVal = values[i].isa<BoolAttr>()
+ ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
+ : values[i].cast<IntegerAttr>().getValue();
+ break;
+ default:
+ llvm_unreachable("unexpected element type");
+ }
+ assert(intVal.getBitWidth() == bitWidth &&
+ "expected value to have same bitwidth as element type");
+ writeBits(data.data(), i * storageBitWidth, intVal);
+ }
+ return getRaw(type, data, /*isSplat=*/(values.size() == 1));
+}
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<bool> values) {
+ assert(hasSameElementsOrSplat(type, values));
+ assert(type.getElementType().isInteger(1));
+
+ std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
+ for (int i = 0, e = values.size(); i != e; ++i)
+ setBit(buff.data(), i, values[i]);
+ return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
+}
+
+/// Constructs a dense integer elements attribute from an array of APInt
+/// values. Each APInt value is expected to have the same bitwidth as the
+/// element type of 'type'.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<APInt> values) {
+ assert(type.getElementType().isa<IntegerType>());
+ return getRaw(type, values);
+}
+
+// Constructs a dense float elements attribute from an array of APFloat
+// values. Each APFloat value is expected to have the same bitwidth as the
+// element type of 'type'.
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<APFloat> values) {
+ assert(type.getElementType().isa<FloatType>());
+
+ // Convert the APFloat values to APInt and create a dense elements attribute.
+ std::vector<APInt> intValues(values.size());
+ for (unsigned i = 0, e = values.size(); i != e; ++i)
+ intValues[i] = values[i].bitcastToAPInt();
+ return getRaw(type, intValues);
+}
+
+// Constructs a dense elements attribute from an array of raw APInt values.
+// Each APInt value is expected to have the same bitwidth as the element type
+// of 'type'.
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+ ArrayRef<APInt> values) {
+ assert(hasSameElementsOrSplat(type, values));
+
+ size_t bitWidth = getDenseElementBitwidth(type.getElementType());
+ size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+ std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+ values.size());
+ for (unsigned i = 0, e = values.size(); i != e; ++i) {
+ assert(values[i].getBitWidth() == bitWidth);
+ writeBits(elementData.data(), i * storageBitWidth, values[i]);
+ }
+ return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
+}
+
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+ ArrayRef<char> data, bool isSplat) {
+ assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+ "type must be ranked tensor or vector");
+ assert(type.hasStaticShape() && "type must have static shape");
+ return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
+ data, isSplat);
+}
+
+/// Check the information for a c++ data type, check if this type is valid for
+/// the current attribute. This method is used to verify specific type
+/// invariants that the templatized 'getValues' method cannot.
+static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
+ bool isInt) {
+ // Make sure that the data element size is the same as the type element width.
+ if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
+ return false;
+
+ // Check that the element type is valid.
+ return isInt ? type.getElementType().isa<IntegerType>()
+ : type.getElementType().isa<FloatType>();
+}
+
+/// Overload of the 'getRaw' method that asserts that the given type is of
+/// integer type. This method is used to verify type invariants that the
+/// templatized 'get' method cannot.
+DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt) {
+ assert(::isValidIntOrFloat(type, dataEltSize, isInt));
+
+ int64_t numElements = data.size() / dataEltSize;
+ assert(numElements == 1 || numElements == type.getNumElements());
+ return getRaw(type, data, /*isSplat=*/numElements == 1);
+}
+
+/// A method used to verify specific type invariants that the templatized 'get'
+/// method cannot.
+bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
+ bool isInt) const {
+ return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
+}
+
+/// Return the raw storage data held by this attribute.
+ArrayRef<char> DenseElementsAttr::getRawData() const {
+ return static_cast<ImplType *>(impl)->data;
+}
+
+/// Returns if this attribute corresponds to a splat, i.e. if all element
+/// values are the same.
+bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
+
+/// Return the held element values as a range of Attributes.
+auto DenseElementsAttr::getAttributeValues() const
+ -> llvm::iterator_range<AttributeElementIterator> {
+ return {attr_value_begin(), attr_value_end()};
+}
+auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
+ return AttributeElementIterator(*this, 0);
+}
+auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
+ return AttributeElementIterator(*this, getNumElements());
+}
+
+/// Return the held element values as a range of bool. The element type of
+/// this attribute must be of integer type of bitwidth 1.
+auto DenseElementsAttr::getBoolValues() const
+ -> llvm::iterator_range<BoolElementIterator> {
+ auto eltType = getType().getElementType().dyn_cast<IntegerType>();
+ assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
+ (void)eltType;
+ return {BoolElementIterator(*this, 0),
+ BoolElementIterator(*this, getNumElements())};
+}
+
+/// Return the held element values as a range of APInts. The element type of
+/// this attribute must be of integer type.
+auto DenseElementsAttr::getIntValues() const
+ -> llvm::iterator_range<IntElementIterator> {
+ assert(getType().getElementType().isa<IntegerType>() &&
+ "expected integer type");
+ return {raw_int_begin(), raw_int_end()};
+}
+auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
+ assert(getType().getElementType().isa<IntegerType>() &&
+ "expected integer type");
+ return raw_int_begin();
+}
+auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
+ assert(getType().getElementType().isa<IntegerType>() &&
+ "expected integer type");
+ return raw_int_end();
+}
+
+/// Return the held element values as a range of APFloat. The element type of
+/// this attribute must be of float type.
+auto DenseElementsAttr::getFloatValues() const
+ -> llvm::iterator_range<FloatElementIterator> {
+ auto elementType = getType().getElementType().cast<FloatType>();
+ assert(elementType.isa<FloatType>() && "expected float type");
+ const auto &elementSemantics = elementType.getFloatSemantics();
+ return {FloatElementIterator(elementSemantics, raw_int_begin()),
+ FloatElementIterator(elementSemantics, raw_int_end())};
+}
+auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
+ return getFloatValues().begin();
+}
+auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
+ return getFloatValues().end();
+}
+
+/// Return a new DenseElementsAttr that has the same data as the current
+/// attribute, but has been reshaped to 'newType'. The new type must have the
+/// same total number of elements as well as element type.
+DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+ ShapedType curType = getType();
+ if (curType == newType)
+ return *this;
+
+ (void)curType;
+ assert(newType.getElementType() == curType.getElementType() &&
+ "expected the same element type");
+ assert(newType.getNumElements() == curType.getNumElements() &&
+ "expected the same number of elements");
+ return getRaw(newType, getRawData(), isSplat());
+}
+
+DenseElementsAttr
+DenseElementsAttr::mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const {
+ return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+}
+
+DenseElementsAttr DenseElementsAttr::mapValues(
+ Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
+ return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+}
+
+//===----------------------------------------------------------------------===//
+// DenseFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+template <typename Fn, typename Attr>
+static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
+ Type newElementType,
+ llvm::SmallVectorImpl<char> &data) {
+ size_t bitWidth = getDenseElementBitwidth(newElementType);
+ size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+ ShapedType newArrayType;
+ if (inType.isa<RankedTensorType>())
+ newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+ else if (inType.isa<UnrankedTensorType>())
+ newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
+ else if (inType.isa<VectorType>())
+ newArrayType = VectorType::get(inType.getShape(), newElementType);
+ else
+ assert(newArrayType && "Unhandled tensor type");
+
+ size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
+ data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
+
+ // Functor used to process a single element value of the attribute.
+ auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
+ auto newInt = mapping(value);
+ assert(newInt.getBitWidth() == bitWidth);
+ writeBits(data.data(), index * storageBitWidth, newInt);
+ };
+
+ // Check for the splat case.
+ if (attr.isSplat()) {
+ processElt(*attr.begin(), /*index=*/0);
+ return newArrayType;
+ }
+
+ // Otherwise, process all of the element values.
+ uint64_t elementIdx = 0;
+ for (auto value : attr)
+ processElt(value, elementIdx++);
+ return newArrayType;
+}
+
+DenseElementsAttr DenseFPElementsAttr::mapValues(
+ Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
+ llvm::SmallVector<char, 8> elementData;
+ auto newArrayType =
+ mappingHelper(mapping, *this, getType(), newElementType, elementData);
+
+ return getRaw(newArrayType, elementData, isSplat());
+}
+
+/// Method for supporting type inquiry through isa, cast and dyn_cast.
+bool DenseFPElementsAttr::classof(Attribute attr) {
+ return attr.isa<DenseElementsAttr>() &&
+ attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseElementsAttr DenseIntElementsAttr::mapValues(
+ Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
+ llvm::SmallVector<char, 8> elementData;
+ auto newArrayType =
+ mappingHelper(mapping, *this, getType(), newElementType, elementData);
+
+ return getRaw(newArrayType, elementData, isSplat());
+}
+
+/// Method for supporting type inquiry through isa, cast and dyn_cast.
+bool DenseIntElementsAttr::classof(Attribute attr) {
+ return attr.isa<DenseElementsAttr>() &&
+ attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
+}
+
+//===----------------------------------------------------------------------===//
+// OpaqueElementsAttr
+//===----------------------------------------------------------------------===//
+
+OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
+ StringRef bytes) {
+ assert(TensorType::isValidElementType(type.getElementType()) &&
+ "Input element type should be a valid tensor element type");
+ return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
+ dialect, bytes);
+}
+
+StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
+
+/// Return the value at the given index. If index does not refer to a valid
+/// element, then a null attribute is returned.
+Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+ assert(isValidIndex(index) && "expected valid multi-dimensional index");
+ if (Dialect *dialect = getDialect())
+ return dialect->extractElementHook(*this, index);
+ return Attribute();
+}
+
+Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
+
+bool OpaqueElementsAttr::decode(ElementsAttr &result) {
+ if (auto *d = getDialect())
+ return d->decodeHook(*this, result);
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+SparseElementsAttr SparseElementsAttr::get(ShapedType type,
+ DenseElementsAttr indices,
+ DenseElementsAttr values) {
+ assert(indices.getType().getElementType().isInteger(64) &&
+ "expected sparse indices to be 64-bit integer values");
+ assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+ "type must be ranked tensor or vector");
+ assert(type.hasStaticShape() && "type must have static shape");
+ return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
+ indices.cast<DenseIntElementsAttr>(), values);
+}
+
+DenseIntElementsAttr SparseElementsAttr::getIndices() const {
+ return getImpl()->indices;
+}
+
+DenseElementsAttr SparseElementsAttr::getValues() const {
+ return getImpl()->values;
+}
+
+/// Return the value of the element at the given index.
+Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
+ assert(isValidIndex(index) && "expected valid multi-dimensional index");
+ auto type = getType();
+
+ // The sparse indices are 64-bit integers, so we can reinterpret the raw data
+ // as a 1-D index array.
+ auto sparseIndices = getIndices();
+ auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
+
+ // Check to see if the indices are a splat.
+ if (sparseIndices.isSplat()) {
+ // If the index is also not a splat of the index value, we know that the
+ // value is zero.
+ auto splatIndex = *sparseIndexValues.begin();
+ if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
+ return getZeroAttr();
+
+ // If the indices are a splat, we also expect the values to be a splat.
+ assert(getValues().isSplat() && "expected splat values");
+ return getValues().getSplatValue();
+ }
+
+ // Build a mapping between known indices and the offset of the stored element.
+ llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
+ auto numSparseIndices = sparseIndices.getType().getDimSize(0);
+ size_t rank = type.getRank();
+ for (size_t i = 0, e = numSparseIndices; i != e; ++i)
+ mappedIndices.try_emplace(
+ {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
+
+ // Look for the provided index key within the mapped indices. If the provided
+ // index is not found, then return a zero attribute.
+ auto it = mappedIndices.find(index);
+ if (it == mappedIndices.end())
+ return getZeroAttr();
+
+ // Otherwise, return the held sparse value element.
+ return getValues().getValue(it->second);
+}
+
+/// Get a zero APFloat for the given sparse attribute.
+APFloat SparseElementsAttr::getZeroAPFloat() const {
+ auto eltType = getType().getElementType().cast<FloatType>();
+ return APFloat(eltType.getFloatSemantics());
+}
+
+/// Get a zero APInt for the given sparse attribute.
+APInt SparseElementsAttr::getZeroAPInt() const {
+ auto eltType = getType().getElementType().cast<IntegerType>();
+ return APInt::getNullValue(eltType.getWidth());
+}
+
+/// Get a zero attribute for the given attribute type.
+Attribute SparseElementsAttr::getZeroAttr() const {
+ auto eltType = getType().getElementType();
+
+ // Handle floating point elements.
+ if (eltType.isa<FloatType>())
+ return FloatAttr::get(eltType, 0);
+
+ // Otherwise, this is an integer.
+ auto intEltTy = eltType.cast<IntegerType>();
+ if (intEltTy.getWidth() == 1)
+ return BoolAttr::get(false, eltType.getContext());
+ return IntegerAttr::get(eltType, 0);
+}
+
+/// Flatten, and return, all of the sparse indices in this attribute in
+/// row-major order.
+std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
+ std::vector<ptrdiff_t> flatSparseIndices;
+
+ // The sparse indices are 64-bit integers, so we can reinterpret the raw data
+ // as a 1-D index array.
+ auto sparseIndices = getIndices();
+ auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
+ if (sparseIndices.isSplat()) {
+ SmallVector<uint64_t, 8> indices(getType().getRank(),
+ *sparseIndexValues.begin());
+ flatSparseIndices.push_back(getFlattenedIndex(indices));
+ return flatSparseIndices;
+ }
+
+ // Otherwise, reinterpret each index as an ArrayRef when flattening.
+ auto numSparseIndices = sparseIndices.getType().getDimSize(0);
+ size_t rank = getType().getRank();
+ for (size_t i = 0, e = numSparseIndices; i != e; ++i)
+ flatSparseIndices.push_back(getFlattenedIndex(
+ {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
+ return flatSparseIndices;
+}
+
+//===----------------------------------------------------------------------===//
+// NamedAttributeList
+//===----------------------------------------------------------------------===//
+
+NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
+ setAttrs(attributes);
+}
+
+ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
+ return attrs ? attrs.getValue() : llvm::None;
+}
+
+/// Replace the held attributes with ones provided in 'newAttrs'.
+void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
+ // Don't create an attribute list if there are no attributes.
+ if (attributes.empty())
+ attrs = nullptr;
+ else
+ attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute NamedAttributeList::get(StringRef name) const {
+ return attrs ? attrs.get(name) : nullptr;
+}
+
+/// Return the specified attribute if present, null otherwise.
+Attribute NamedAttributeList::get(Identifier name) const {
+ return attrs ? attrs.get(name) : nullptr;
+}
+
+/// If the an attribute exists with the specified name, change it to the new
+/// value. Otherwise, add a new attribute with the specified name/value.
+void NamedAttributeList::set(Identifier name, Attribute value) {
+ assert(value && "attributes may never be null");
+
+ // If we already have this attribute, replace it.
+ auto origAttrs = getAttrs();
+ SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
+ for (auto &elt : newAttrs)
+ if (elt.first == name) {
+ elt.second = value;
+ attrs = DictionaryAttr::get(newAttrs, value.getContext());
+ return;
+ }
+
+ // Otherwise, add it.
+ newAttrs.push_back({name, value});
+ attrs = DictionaryAttr::get(newAttrs, value.getContext());
+}
+
+/// Remove the attribute with the specified name if it exists. The return
+/// value indicates whether the attribute was present or not.
+auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
+ auto origAttrs = getAttrs();
+ for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
+ if (origAttrs[i].first == name) {
+ // Handle the simple case of removing the only attribute in the list.
+ if (e == 1) {
+ attrs = nullptr;
+ return RemoveResult::Removed;
+ }
+
+ SmallVector<NamedAttribute, 8> newAttrs;
+ newAttrs.reserve(origAttrs.size() - 1);
+ newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
+ newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
+ attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
+ return RemoveResult::Removed;
+ }
+ }
+ return RemoveResult::NotFound;
+}
OpenPOWER on IntegriCloud