//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the types in the SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/StandardTypes.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; using namespace mlir::spirv; // Pull in all enum utility function definitions #include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc" // Pull in all enum type availability query function definitions #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc" //===----------------------------------------------------------------------===// // ArrayType //===----------------------------------------------------------------------===// struct spirv::detail::ArrayTypeStorage : public TypeStorage { using KeyTy = std::tuple; static ArrayTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) ArrayTypeStorage(key); } bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, getSubclassData(), layoutInfo); } ArrayTypeStorage(const KeyTy &key) : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)), layoutInfo(std::get<2>(key)) {} Type elementType; ArrayType::LayoutInfo layoutInfo; }; ArrayType ArrayType::get(Type elementType, unsigned elementCount) { assert(elementCount && "ArrayType needs at least one element"); return Base::get(elementType.getContext(), TypeKind::Array, elementType, elementCount, 0); } ArrayType ArrayType::get(Type elementType, unsigned elementCount, ArrayType::LayoutInfo layoutInfo) { assert(elementCount && "ArrayType needs at least one element"); return Base::get(elementType.getContext(), TypeKind::Array, elementType, elementCount, layoutInfo); } unsigned ArrayType::getNumElements() const { return getImpl()->getSubclassData(); } Type ArrayType::getElementType() const { return getImpl()->elementType; } // ArrayStride must be greater than zero bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; } uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; } //===----------------------------------------------------------------------===// // CompositeType //===----------------------------------------------------------------------===// bool CompositeType::classof(Type type) { switch (type.getKind()) { case TypeKind::Array: case TypeKind::RuntimeArray: case TypeKind::Struct: case StandardTypes::Vector: return true; default: return false; } } Type CompositeType::getElementType(unsigned index) const { switch (getKind()) { case spirv::TypeKind::Array: return cast().getElementType(); case spirv::TypeKind::RuntimeArray: return cast().getElementType(); case spirv::TypeKind::Struct: return cast().getElementType(index); case StandardTypes::Vector: return cast().getElementType(); default: llvm_unreachable("invalid composite type"); } } unsigned CompositeType::getNumElements() const { switch (getKind()) { case spirv::TypeKind::Array: return cast().getNumElements(); case spirv::TypeKind::RuntimeArray: llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); case spirv::TypeKind::Struct: return cast().getNumElements(); case StandardTypes::Vector: return cast().getNumElements(); default: llvm_unreachable("invalid composite type"); } } //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===// template static constexpr unsigned getNumBits() { return 0; } template <> constexpr unsigned getNumBits() { static_assert((1 << 3) > getMaxEnumValForDim(), "Not enough bits to encode Dim value"); return 3; } template <> constexpr unsigned getNumBits() { static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(), "Not enough bits to encode ImageDepthInfo value"); return 2; } template <> constexpr unsigned getNumBits() { static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(), "Not enough bits to encode ImageArrayedInfo value"); return 1; } template <> constexpr unsigned getNumBits() { static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(), "Not enough bits to encode ImageSamplingInfo value"); return 1; } template <> constexpr unsigned getNumBits() { static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(), "Not enough bits to encode ImageSamplerUseInfo value"); return 2; } template <> constexpr unsigned getNumBits() { static_assert((1 << 6) > getMaxEnumValForImageFormat(), "Not enough bits to encode ImageFormat value"); return 6; } struct spirv::detail::ImageTypeStorage : public TypeStorage { private: /// Define a bit-field struct to pack the enum values union EnumPack { struct { unsigned dimEncoding : getNumBits(); unsigned depthInfoEncoding : getNumBits(); unsigned arrayedInfoEncoding : getNumBits(); unsigned samplingInfoEncoding : getNumBits(); unsigned samplerUseInfoEncoding : getNumBits(); unsigned formatEncoding : getNumBits(); } data; unsigned storage; }; public: using KeyTy = std::tuple; static ImageTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) ImageTypeStorage(key); } bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, getDim(), getDepthInfo(), getArrayedInfo(), getSamplingInfo(), getSamplerUseInfo(), getImageFormat()); } Dim getDim() const { EnumPack v; v.storage = getSubclassData(); return static_cast(v.data.dimEncoding); } void setDim(Dim dim) { EnumPack v; v.storage = getSubclassData(); v.data.dimEncoding = static_cast(dim); setSubclassData(v.storage); } ImageDepthInfo getDepthInfo() const { EnumPack v; v.storage = getSubclassData(); return static_cast(v.data.depthInfoEncoding); } void setDepthInfo(ImageDepthInfo depthInfo) { EnumPack v; v.storage = getSubclassData(); v.data.depthInfoEncoding = static_cast(depthInfo); setSubclassData(v.storage); } ImageArrayedInfo getArrayedInfo() const { EnumPack v; v.storage = getSubclassData(); return static_cast(v.data.arrayedInfoEncoding); } void setArrayedInfo(ImageArrayedInfo arrayedInfo) { EnumPack v; v.storage = getSubclassData(); v.data.arrayedInfoEncoding = static_cast(arrayedInfo); setSubclassData(v.storage); } ImageSamplingInfo getSamplingInfo() const { EnumPack v; v.storage = getSubclassData(); return static_cast(v.data.samplingInfoEncoding); } void setSamplingInfo(ImageSamplingInfo samplingInfo) { EnumPack v; v.storage = getSubclassData(); v.data.samplingInfoEncoding = static_cast(samplingInfo); setSubclassData(v.storage); } ImageSamplerUseInfo getSamplerUseInfo() const { EnumPack v; v.storage = getSubclassData(); return static_cast(v.data.samplerUseInfoEncoding); } void setSamplerUseInfo(ImageSamplerUseInfo samplerUseInfo) { EnumPack v; v.storage = getSubclassData(); v.data.samplerUseInfoEncoding = static_cast(samplerUseInfo); setSubclassData(v.storage); } ImageFormat getImageFormat() const { EnumPack v; v.storage = getSubclassData(); return static_cast(v.data.formatEncoding); } void setImageFormat(ImageFormat format) { EnumPack v; v.storage = getSubclassData(); v.data.formatEncoding = static_cast(format); setSubclassData(v.storage); } ImageTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)) { static_assert(sizeof(EnumPack) <= sizeof(getSubclassData()), "EnumPack size greater than subClassData type size"); setDim(std::get<1>(key)); setDepthInfo(std::get<2>(key)); setArrayedInfo(std::get<3>(key)); setSamplingInfo(std::get<4>(key)); setSamplerUseInfo(std::get<5>(key)); setImageFormat(std::get<6>(key)); } Type elementType; }; ImageType ImageType::get(std::tuple value) { return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value); } Type ImageType::getElementType() const { return getImpl()->elementType; } Dim ImageType::getDim() const { return getImpl()->getDim(); } ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->getDepthInfo(); } ImageArrayedInfo ImageType::getArrayedInfo() const { return getImpl()->getArrayedInfo(); } ImageSamplingInfo ImageType::getSamplingInfo() const { return getImpl()->getSamplingInfo(); } ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { return getImpl()->getSamplerUseInfo(); } ImageFormat ImageType::getImageFormat() const { return getImpl()->getImageFormat(); } //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// struct spirv::detail::PointerTypeStorage : public TypeStorage { // (Type, StorageClass) as the key: Type stored in this struct, and // StorageClass stored as TypeStorage's subclass data. using KeyTy = std::pair; static PointerTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) PointerTypeStorage(key); } bool operator==(const KeyTy &key) const { return key == KeyTy(pointeeType, getStorageClass()); } PointerTypeStorage(const KeyTy &key) : TypeStorage(static_cast(key.second)), pointeeType(key.first) { } StorageClass getStorageClass() const { return static_cast(getSubclassData()); } Type pointeeType; }; PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType, storageClass); } Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } StorageClass PointerType::getStorageClass() const { return getImpl()->getStorageClass(); } //===----------------------------------------------------------------------===// // RuntimeArrayType //===----------------------------------------------------------------------===// struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage { using KeyTy = Type; static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) RuntimeArrayTypeStorage(key); } bool operator==(const KeyTy &key) const { return elementType == key; } RuntimeArrayTypeStorage(const KeyTy &key) : elementType(key) {} Type elementType; }; RuntimeArrayType RuntimeArrayType::get(Type elementType) { return Base::get(elementType.getContext(), TypeKind::RuntimeArray, elementType); } Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } //===----------------------------------------------------------------------===// // StructType //===----------------------------------------------------------------------===// struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo) : TypeStorage(numMembers), memberTypes(memberTypes), layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations), memberDecorationsInfo(memberDecorationsInfo) {} using KeyTy = std::tuple, ArrayRef, ArrayRef>; bool operator==(const KeyTy &key) const { return key == KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo()); } static StructTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { ArrayRef keyTypes = std::get<0>(key); // Copy the member type and layout information into the bump pointer const Type *typesList = nullptr; if (!keyTypes.empty()) { typesList = allocator.copyInto(keyTypes).data(); } const StructType::LayoutInfo *layoutInfoList = nullptr; if (!std::get<1>(key).empty()) { ArrayRef keyLayoutInfo = std::get<1>(key); assert(keyLayoutInfo.size() == keyTypes.size() && "size of layout information must be same as the size of number of " "elements"); layoutInfoList = allocator.copyInto(keyLayoutInfo).data(); } const StructType::MemberDecorationInfo *memberDecorationList = nullptr; unsigned numMemberDecorations = 0; if (!std::get<2>(key).empty()) { auto keyMemberDecorations = std::get<2>(key); numMemberDecorations = keyMemberDecorations.size(); memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } return new (allocator.allocate()) StructTypeStorage(keyTypes.size(), typesList, layoutInfoList, numMemberDecorations, memberDecorationList); } ArrayRef getMemberTypes() const { return ArrayRef(memberTypes, getSubclassData()); } ArrayRef getLayoutInfo() const { if (layoutInfo) { return ArrayRef(layoutInfo, getSubclassData()); } return {}; } ArrayRef getMemberDecorationsInfo() const { if (memberDecorationsInfo) { return ArrayRef(memberDecorationsInfo, numMemberDecorations); } return {}; } Type const *memberTypes; StructType::LayoutInfo const *layoutInfo; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; }; StructType StructType::get(ArrayRef memberTypes, ArrayRef layoutInfo, ArrayRef memberDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. SmallVector sortedDecorations( memberDecorations.begin(), memberDecorations.end()); llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct, memberTypes, layoutInfo, sortedDecorations); } StructType StructType::getEmpty(MLIRContext *context) { return Base::get(context, TypeKind::Struct, ArrayRef(), ArrayRef(), ArrayRef()); } unsigned StructType::getNumElements() const { return getImpl()->getSubclassData(); } Type StructType::getElementType(unsigned index) const { assert( getNumElements() > index && "element index is more than number of members of the SPIR-V StructType"); return getImpl()->memberTypes[index]; } bool StructType::hasLayout() const { return getImpl()->layoutInfo; } uint64_t StructType::getOffset(unsigned index) const { assert( getNumElements() > index && "element index is more than number of members of the SPIR-V StructType"); return getImpl()->layoutInfo[index]; } void StructType::getMemberDecorations( SmallVectorImpl &memberDecorations) const { memberDecorations.clear(); auto implMemberDecorations = getImpl()->getMemberDecorationsInfo(); memberDecorations.append(implMemberDecorations.begin(), implMemberDecorations.end()); } void StructType::getMemberDecorations( unsigned index, SmallVectorImpl &decorations) const { assert(getNumElements() > index && "member index out of range"); auto memberDecorations = getImpl()->getMemberDecorationsInfo(); decorations.clear(); for (auto &memberDecoration : memberDecorations) { if (memberDecoration.first == index) { decorations.push_back(memberDecoration.second); } if (memberDecoration.first > index) { // Early exit since the decorations are stored sorted. return; } } }