diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/IR/AttributeDetail.h | 302 | ||||
-rw-r--r-- | mlir/lib/IR/Attributes.cpp | 214 |
2 files changed, 248 insertions, 268 deletions
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index f164f505f33..0fe07a97916 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -24,7 +24,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/Function.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" @@ -36,28 +35,43 @@ namespace mlir { namespace detail { -/// Opaque Attribute Storage and Uniquing. -struct OpaqueAttributeStorage : public AttributeStorage { - OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData) - : dialectNamespace(dialectNamespace), attrData(attrData) {} +// An attribute representing a reference to an affine map. +struct AffineMapAttributeStorage : public AttributeStorage { + using KeyTy = AffineMap; - /// The hash key used for uniquing. - using KeyTy = std::pair<Identifier, StringRef>; - bool operator==(const KeyTy &key) const { - return key == KeyTy(dialectNamespace, attrData); - } + AffineMapAttributeStorage(AffineMap value) + : AttributeStorage(IndexType::get(value.getContext())), value(value) {} - static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate<OpaqueAttributeStorage>()) - OpaqueAttributeStorage(key.first, allocator.copyInto(key.second)); + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static AffineMapAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate<AffineMapAttributeStorage>()) + AffineMapAttributeStorage(key); } - // The dialect namespace. - Identifier dialectNamespace; + AffineMap value; +}; - // The parser attribute data for this opaque attribute. - StringRef attrData; +/// An attribute representing an array of other attributes. +struct ArrayAttributeStorage : public AttributeStorage { + using KeyTy = ArrayRef<Attribute>; + + ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate<ArrayAttributeStorage>()) + ArrayAttributeStorage(allocator.copyInto(key)); + } + + ArrayRef<Attribute> value; }; /// An attribute representing a boolean value. @@ -82,51 +96,51 @@ struct BoolAttributeStorage : public AttributeStorage { bool value; }; -/// An attribute representing a integral value. -struct IntegerAttributeStorage final +/// An attribute representing a dictionary of sorted named attributes. +struct DictionaryAttributeStorage final : public AttributeStorage, - public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> { - using KeyTy = std::pair<Type, APInt>; + private llvm::TrailingObjects<DictionaryAttributeStorage, + NamedAttribute> { + using KeyTy = ArrayRef<NamedAttribute>; - IntegerAttributeStorage(Type type, size_t numObjects) - : AttributeStorage(type), numObjects(numObjects) { - assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type"); - } + /// Given a list of NamedAttribute's, canonicalize the list (sorting + /// by name) and return the unique'd result. + static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs); - /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { - return key == KeyTy(getType(), getValue()); - } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(key.first, llvm::hash_value(key.second)); - } + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == getElements(); } /// Construct a new storage instance. - static IntegerAttributeStorage * + static DictionaryAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key) { - Type type; - APInt value; - std::tie(type, value) = key; + auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>( + key.size()); + auto rawMem = allocator.allocate(size, alignof(NamedAttribute)); - auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords()); - auto size = - IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size()); - auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage)); - auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), - result->getTrailingObjects<uint64_t>()); + // Initialize the storage and trailing attribute list. + auto result = ::new (rawMem) DictionaryAttributeStorage(key.size()); + std::uninitialized_copy(key.begin(), key.end(), + result->getTrailingObjects<NamedAttribute>()); return result; } - /// Returns an APInt representing the stored value. - APInt getValue() const { - if (getType().isIndex()) - return APInt(64, {getTrailingObjects<uint64_t>(), numObjects}); - return APInt(getType().getIntOrFloatBitWidth(), - {getTrailingObjects<uint64_t>(), numObjects}); + /// Return the elements of this dictionary attribute. + ArrayRef<NamedAttribute> getElements() const { + return {getTrailingObjects<NamedAttribute>(), numElements}; } - size_t numObjects; +private: + friend class llvm::TrailingObjects<DictionaryAttributeStorage, + NamedAttribute>; + + // This is used by the llvm::TrailingObjects base class. + size_t numTrailingObjects(OverloadToken<NamedAttribute>) const { + return numElements; + } + DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {} + + /// This is the number of attributes. + const unsigned numElements; }; /// An attribute representing a floating point value. @@ -191,128 +205,113 @@ struct FloatAttributeStorage final size_t numObjects; }; -/// An attribute representing a string value. -struct StringAttributeStorage : public AttributeStorage { - using KeyTy = StringRef; - - StringAttributeStorage(StringRef value) : value(value) {} - - /// Key equality function. - bool operator==(const KeyTy &key) const { return key == value; } +/// An attribute representing a integral value. +struct IntegerAttributeStorage final + : public AttributeStorage, + public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> { + using KeyTy = std::pair<Type, APInt>; - /// Construct a new storage instance. - static StringAttributeStorage *construct(AttributeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate<StringAttributeStorage>()) - StringAttributeStorage(allocator.copyInto(key)); + IntegerAttributeStorage(Type type, size_t numObjects) + : AttributeStorage(type), numObjects(numObjects) { + assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type"); } - StringRef value; -}; - -/// An attribute representing an array of other attributes. -struct ArrayAttributeStorage : public AttributeStorage { - using KeyTy = ArrayRef<Attribute>; - - ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {} - - /// Key equality function. - bool operator==(const KeyTy &key) const { return key == value; } - - /// Construct a new storage instance. - static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate<ArrayAttributeStorage>()) - ArrayAttributeStorage(allocator.copyInto(key)); + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { + return key == KeyTy(getType(), getValue()); + } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(key.first, llvm::hash_value(key.second)); } - - ArrayRef<Attribute> value; -}; - -/// An attribute representing a dictionary of sorted named attributes. -struct DictionaryAttributeStorage final - : public AttributeStorage, - private llvm::TrailingObjects<DictionaryAttributeStorage, - NamedAttribute> { - using KeyTy = ArrayRef<NamedAttribute>; - - /// Given a list of NamedAttribute's, canonicalize the list (sorting - /// by name) and return the unique'd result. - static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs); - - /// Key equality function. - bool operator==(const KeyTy &key) const { return key == getElements(); } /// Construct a new storage instance. - static DictionaryAttributeStorage * + static IntegerAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key) { - auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>( - key.size()); - auto rawMem = allocator.allocate(size, alignof(NamedAttribute)); + Type type; + APInt value; + std::tie(type, value) = key; - // Initialize the storage and trailing attribute list. - auto result = ::new (rawMem) DictionaryAttributeStorage(key.size()); - std::uninitialized_copy(key.begin(), key.end(), - result->getTrailingObjects<NamedAttribute>()); + auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords()); + auto size = + IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size()); + auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage)); + auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), + result->getTrailingObjects<uint64_t>()); return result; } - /// Return the elements of this dictionary attribute. - ArrayRef<NamedAttribute> getElements() const { - return {getTrailingObjects<NamedAttribute>(), numElements}; - } - -private: - friend class llvm::TrailingObjects<DictionaryAttributeStorage, - NamedAttribute>; - - // This is used by the llvm::TrailingObjects base class. - size_t numTrailingObjects(OverloadToken<NamedAttribute>) const { - return numElements; + /// Returns an APInt representing the stored value. + APInt getValue() const { + if (getType().isIndex()) + return APInt(64, {getTrailingObjects<uint64_t>(), numObjects}); + return APInt(getType().getIntOrFloatBitWidth(), + {getTrailingObjects<uint64_t>(), numObjects}); } - DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {} - /// This is the number of attributes. - const unsigned numElements; + size_t numObjects; }; -// An attribute representing a reference to an affine map. -struct AffineMapAttributeStorage : public AttributeStorage { - using KeyTy = AffineMap; +// An attribute representing a reference to an integer set. +struct IntegerSetAttributeStorage : public AttributeStorage { + using KeyTy = IntegerSet; - AffineMapAttributeStorage(AffineMap value) - : AttributeStorage(IndexType::get(value.getContext())), value(value) {} + IntegerSetAttributeStorage(IntegerSet value) : value(value) {} /// Key equality function. bool operator==(const KeyTy &key) const { return key == value; } /// Construct a new storage instance. - static AffineMapAttributeStorage * + static IntegerSetAttributeStorage * construct(AttributeStorageAllocator &allocator, KeyTy key) { - return new (allocator.allocate<AffineMapAttributeStorage>()) - AffineMapAttributeStorage(key); + return new (allocator.allocate<IntegerSetAttributeStorage>()) + IntegerSetAttributeStorage(key); } - AffineMap value; + IntegerSet value; }; -// An attribute representing a reference to an integer set. -struct IntegerSetAttributeStorage : public AttributeStorage { - using KeyTy = IntegerSet; +/// Opaque Attribute Storage and Uniquing. +struct OpaqueAttributeStorage : public AttributeStorage { + OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData) + : dialectNamespace(dialectNamespace), attrData(attrData) {} - IntegerSetAttributeStorage(IntegerSet value) : value(value) {} + /// The hash key used for uniquing. + using KeyTy = std::pair<Identifier, StringRef>; + bool operator==(const KeyTy &key) const { + return key == KeyTy(dialectNamespace, attrData); + } + + static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate<OpaqueAttributeStorage>()) + OpaqueAttributeStorage(key.first, allocator.copyInto(key.second)); + } + + // The dialect namespace. + Identifier dialectNamespace; + + // The parser attribute data for this opaque attribute. + StringRef attrData; +}; + +/// An attribute representing a string value. +struct StringAttributeStorage : public AttributeStorage { + using KeyTy = StringRef; + + StringAttributeStorage(StringRef value) : value(value) {} /// Key equality function. bool operator==(const KeyTy &key) const { return key == value; } /// Construct a new storage instance. - static IntegerSetAttributeStorage * - construct(AttributeStorageAllocator &allocator, KeyTy key) { - return new (allocator.allocate<IntegerSetAttributeStorage>()) - IntegerSetAttributeStorage(key); + static StringAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate<StringAttributeStorage>()) + StringAttributeStorage(allocator.copyInto(key)); } - IntegerSet value; + StringRef value; }; /// An attribute representing a reference to a type. @@ -334,28 +333,9 @@ struct TypeAttributeStorage : public AttributeStorage { Type value; }; -/// An attribute representing a reference to a vector or tensor constant, -/// inwhich all elements have the same value. -struct SplatElementsAttributeStorage : public AttributeStorage { - using KeyTy = std::pair<Type, Attribute>; - - SplatElementsAttributeStorage(Type type, Attribute elt) - : AttributeStorage(type), elt(elt) {} - - /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { - return key == std::make_pair(getType(), elt); - } - - /// Construct a new storage instance. - static SplatElementsAttributeStorage * - construct(AttributeStorageAllocator &allocator, KeyTy key) { - return new (allocator.allocate<SplatElementsAttributeStorage>()) - SplatElementsAttributeStorage(key.first, key.second); - } - - Attribute elt; -}; +//===----------------------------------------------------------------------===// +// Elements Attributes +//===----------------------------------------------------------------------===// /// An attribute representing a reference to a dense vector or tensor object. struct DenseElementsAttributeStorage : public AttributeStorage { diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index bbf89958213..ce33508830c 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -57,41 +57,26 @@ MLIRContext *Attribute::getContext() const { return getType().getContext(); } Dialect &Attribute::getDialect() const { return impl->getDialect(); } //===----------------------------------------------------------------------===// -// OpaqueAttr +// AffineMapAttr //===----------------------------------------------------------------------===// -OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, - MLIRContext *context) { - return Base::get(context, StandardAttributes::Opaque, dialect, attrData); -} - -OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, - MLIRContext *context, Location location) { - return Base::getChecked(location, context, StandardAttributes::Opaque, - dialect, attrData); +AffineMapAttr AffineMapAttr::get(AffineMap value) { + return Base::get(value.getResult(0).getContext(), + StandardAttributes::AffineMap, value); } -/// Returns the dialect namespace of the opaque attribute. -Identifier OpaqueAttr::getDialectNamespace() const { - return getImpl()->dialectNamespace; -} +AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } -/// Returns the raw attribute data of the opaque attribute. -StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } +//===----------------------------------------------------------------------===// +// ArrayAttr +//===----------------------------------------------------------------------===// -/// Verify the construction of an opaque attribute. -LogicalResult OpaqueAttr::verifyConstructionInvariants( - llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect, - StringRef attrData) { - if (!Dialect::isValidNamespace(dialect.strref())) { - if (loc) - context->emitError(*loc) - << "invalid dialect namespace '" << dialect << "'"; - return failure(); - } - return success(); +ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { + return Base::get(context, StandardAttributes::Array, value); } +ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } + //===----------------------------------------------------------------------===// // BoolAttr //===----------------------------------------------------------------------===// @@ -196,27 +181,6 @@ DictionaryAttr::iterator DictionaryAttr::end() const { size_t DictionaryAttr::size() const { return getValue().size(); } //===----------------------------------------------------------------------===// -// IntegerAttr -//===----------------------------------------------------------------------===// - -IntegerAttr IntegerAttr::get(Type type, const APInt &value) { - return Base::get(type.getContext(), StandardAttributes::Integer, type, value); -} - -IntegerAttr IntegerAttr::get(Type type, int64_t value) { - // This uses 64 bit APInts by default for index type. - if (type.isIndex()) - return get(type, APInt(64, value)); - - auto intType = type.cast<IntegerType>(); - return get(type, APInt(intType.getWidth(), value)); -} - -APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } - -int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } - -//===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// @@ -287,35 +251,40 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc, } //===----------------------------------------------------------------------===// -// StringAttr +// FunctionAttr //===----------------------------------------------------------------------===// -StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { - return Base::get(context, StandardAttributes::String, bytes); +FunctionAttr FunctionAttr::get(Function *value) { + assert(value && "Cannot get FunctionAttr for a null function"); + return get(value->getName(), value->getContext()); } -StringRef StringAttr::getValue() const { return getImpl()->value; } +FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) { + return Base::get(ctx, StandardAttributes::Function, value); +} + +StringRef FunctionAttr::getValue() const { return getImpl()->value; } //===----------------------------------------------------------------------===// -// ArrayAttr +// IntegerAttr //===----------------------------------------------------------------------===// -ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { - return Base::get(context, StandardAttributes::Array, value); +IntegerAttr IntegerAttr::get(Type type, const APInt &value) { + return Base::get(type.getContext(), StandardAttributes::Integer, type, value); } -ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } - -//===----------------------------------------------------------------------===// -// AffineMapAttr -//===----------------------------------------------------------------------===// +IntegerAttr IntegerAttr::get(Type type, int64_t value) { + // This uses 64 bit APInts by default for index type. + if (type.isIndex()) + return get(type, APInt(64, value)); -AffineMapAttr AffineMapAttr::get(AffineMap value) { - return Base::get(value.getResult(0).getContext(), - StandardAttributes::AffineMap, value); + auto intType = type.cast<IntegerType>(); + return get(type, APInt(intType.getWidth(), value)); } -AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } +APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } + +int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } //===----------------------------------------------------------------------===// // IntegerSetAttr @@ -329,29 +298,60 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } //===----------------------------------------------------------------------===// -// TypeAttr +// OpaqueAttr //===----------------------------------------------------------------------===// -TypeAttr TypeAttr::get(Type value) { - return Base::get(value.getContext(), StandardAttributes::Type, value); +OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, + MLIRContext *context) { + return Base::get(context, StandardAttributes::Opaque, dialect, attrData); } -Type TypeAttr::getValue() const { return getImpl()->value; } +OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, + MLIRContext *context, Location location) { + return Base::getChecked(location, context, StandardAttributes::Opaque, + dialect, attrData); +} + +/// Returns the dialect namespace of the opaque attribute. +Identifier OpaqueAttr::getDialectNamespace() const { + return getImpl()->dialectNamespace; +} + +/// Returns the raw attribute data of the opaque attribute. +StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } + +/// Verify the construction of an opaque attribute. +LogicalResult OpaqueAttr::verifyConstructionInvariants( + llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect, + StringRef attrData) { + if (!Dialect::isValidNamespace(dialect.strref())) { + if (loc) + context->emitError(*loc) + << "invalid dialect namespace '" << dialect << "'"; + return failure(); + } + return success(); +} //===----------------------------------------------------------------------===// -// FunctionAttr +// StringAttr //===----------------------------------------------------------------------===// -FunctionAttr FunctionAttr::get(Function *value) { - assert(value && "Cannot get FunctionAttr for a null function"); - return get(value->getName(), value->getContext()); +StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { + return Base::get(context, StandardAttributes::String, bytes); } -FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) { - return Base::get(ctx, StandardAttributes::Function, value); +StringRef StringAttr::getValue() const { return getImpl()->value; } + +//===----------------------------------------------------------------------===// +// TypeAttr +//===----------------------------------------------------------------------===// + +TypeAttr TypeAttr::get(Type value) { + return Base::get(value.getContext(), StandardAttributes::Type, value); } -StringRef FunctionAttr::getValue() const { return getImpl()->value; } +Type TypeAttr::getValue() const { return getImpl()->value; } //===----------------------------------------------------------------------===// // ElementsAttr @@ -399,28 +399,6 @@ ElementsAttr ElementsAttr::mapValues( } //===----------------------------------------------------------------------===// -// SplatElementsAttr -//===----------------------------------------------------------------------===// - -SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { - return DenseElementsAttr::get(type, elt).cast<SplatElementsAttr>(); -} - -SplatElementsAttr SplatElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const { - return DenseElementsAttr::mapValues(newElementType, mapping) - .cast<SplatElementsAttr>(); -} - -SplatElementsAttr SplatElementsAttr::mapValues( - Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const { - return DenseElementsAttr::mapValues(newElementType, mapping) - .cast<SplatElementsAttr>(); -} - -//===----------------------------------------------------------------------===// // DenseElementAttr Utilities //===----------------------------------------------------------------------===// @@ -787,7 +765,7 @@ DenseElementsAttr DenseElementsAttr::mapValues( } //===----------------------------------------------------------------------===// -// DenseIntElementsAttr +// DenseFPElementsAttr //===----------------------------------------------------------------------===// template <typename Fn, typename Attr> @@ -820,9 +798,9 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, return newArrayType; } -DenseElementsAttr DenseIntElementsAttr::mapValues( +DenseElementsAttr DenseFPElementsAttr::mapValues( Type newElementType, - llvm::function_ref<APInt(const APInt &)> mapping) const { + llvm::function_ref<APInt(const APFloat &)> mapping) const { llvm::SmallVector<char, 8> elementData; auto newArrayType = mappingHelper(mapping, *this, getType(), newElementType, elementData); @@ -831,18 +809,18 @@ DenseElementsAttr DenseIntElementsAttr::mapValues( } /// Method for supporting type inquiry through isa, cast and dyn_cast. -bool DenseIntElementsAttr::classof(Attribute attr) { +bool DenseFPElementsAttr::classof(Attribute attr) { return attr.isa<DenseElementsAttr>() && - attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); + attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); } //===----------------------------------------------------------------------===// -// DenseFPElementsAttr +// DenseIntElementsAttr //===----------------------------------------------------------------------===// -DenseElementsAttr DenseFPElementsAttr::mapValues( +DenseElementsAttr DenseIntElementsAttr::mapValues( Type newElementType, - llvm::function_ref<APInt(const APFloat &)> mapping) const { + llvm::function_ref<APInt(const APInt &)> mapping) const { llvm::SmallVector<char, 8> elementData; auto newArrayType = mappingHelper(mapping, *this, getType(), newElementType, elementData); @@ -851,9 +829,9 @@ DenseElementsAttr DenseFPElementsAttr::mapValues( } /// Method for supporting type inquiry through isa, cast and dyn_cast. -bool DenseFPElementsAttr::classof(Attribute attr) { +bool DenseIntElementsAttr::classof(Attribute attr) { return attr.isa<DenseElementsAttr>() && - attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); + attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); } //===----------------------------------------------------------------------===// @@ -963,6 +941,28 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { } //===----------------------------------------------------------------------===// +// SplatElementsAttr +//===----------------------------------------------------------------------===// + +SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { + return DenseElementsAttr::get(type, elt).cast<SplatElementsAttr>(); +} + +SplatElementsAttr SplatElementsAttr::mapValues( + Type newElementType, + llvm::function_ref<APInt(const APInt &)> mapping) const { + return DenseElementsAttr::mapValues(newElementType, mapping) + .cast<SplatElementsAttr>(); +} + +SplatElementsAttr SplatElementsAttr::mapValues( + Type newElementType, + llvm::function_ref<APInt(const APFloat &)> mapping) const { + return DenseElementsAttr::mapValues(newElementType, mapping) + .cast<SplatElementsAttr>(); +} + +//===----------------------------------------------------------------------===// // NamedAttributeList //===----------------------------------------------------------------------===// |