diff options
author | Mehdi Amini <aminim@google.com> | 2019-12-24 02:47:41 +0000 |
---|---|---|
committer | Mehdi Amini <aminim@google.com> | 2019-12-24 02:47:41 +0000 |
commit | 0f0d0ed1c78f1a80139a1f2133fad5284691a121 (patch) | |
tree | 31979a3137c364e3eb58e0169a7c4029c7ee7db3 /mlir/lib/Dialect/SPIRV/LayoutUtils.cpp | |
parent | 6f635f90929da9545dd696071a829a1a42f84b30 (diff) | |
parent | 5b4a01d4a63cb66ab981e52548f940813393bf42 (diff) | |
download | bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.tar.gz bcm5719-llvm-0f0d0ed1c78f1a80139a1f2133fad5284691a121.zip |
Import MLIR into the LLVM tree
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/LayoutUtils.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/LayoutUtils.cpp | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp new file mode 100644 index 00000000000..a12d04edd68 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -0,0 +1,156 @@ +//===-- LayoutUtils.cpp - Decorate composite type with layout information -===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements Utilities used to get alignment and layout information +// for types in SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/LayoutUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" + +using namespace mlir; + +spirv::StructType +VulkanLayoutUtils::decorateType(spirv::StructType structType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + if (structType.getNumElements() == 0) { + return structType; + } + + SmallVector<Type, 4> memberTypes; + SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo; + SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; + + VulkanLayoutUtils::Size structMemberOffset = 0; + VulkanLayoutUtils::Size maxMemberAlignment = 1; + + for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { + VulkanLayoutUtils::Size memberSize = 0; + VulkanLayoutUtils::Size memberAlignment = 1; + + auto memberType = VulkanLayoutUtils::decorateType( + structType.getElementType(i), memberSize, memberAlignment); + structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); + memberTypes.push_back(memberType); + layoutInfo.push_back(structMemberOffset); + // According to the Vulkan spec: + // "A structure has a base alignment equal to the largest base alignment of + // any of its members." + structMemberOffset += memberSize; + maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment); + } + + // According to the Vulkan spec: + // "The Offset decoration of a member must not place it between the end of a + // structure or an array and the next multiple of the alignment of that + // structure or array." + size = llvm::alignTo(structMemberOffset, maxMemberAlignment); + alignment = maxMemberAlignment; + structType.getMemberDecorations(memberDecorations); + return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations); +} + +Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + if (spirv::SPIRVDialect::isValidScalarType(type)) { + alignment = VulkanLayoutUtils::getScalarTypeAlignment(type); + // Vulkan spec does not specify any padding for a scalar type. + size = alignment; + return type; + } + + switch (type.getKind()) { + case spirv::TypeKind::Struct: + return VulkanLayoutUtils::decorateType(type.cast<spirv::StructType>(), size, + alignment); + case spirv::TypeKind::Array: + return VulkanLayoutUtils::decorateType(type.cast<spirv::ArrayType>(), size, + alignment); + case StandardTypes::Vector: + return VulkanLayoutUtils::decorateType(type.cast<VectorType>(), size, + alignment); + default: + llvm_unreachable("unhandled SPIR-V type"); + } +} + +Type VulkanLayoutUtils::decorateType(VectorType vectorType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + const auto numElements = vectorType.getNumElements(); + auto elementType = vectorType.getElementType(); + VulkanLayoutUtils::Size elementSize = 0; + VulkanLayoutUtils::Size elementAlignment = 1; + + auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, + elementAlignment); + // According to the Vulkan spec: + // 1. "A two-component vector has a base alignment equal to twice its scalar + // alignment." + // 2. "A three- or four-component vector has a base alignment equal to four + // times its scalar alignment." + size = elementSize * numElements; + alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; + return VectorType::get(numElements, memberType); +} + +Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + const auto numElements = arrayType.getNumElements(); + auto elementType = arrayType.getElementType(); + spirv::ArrayType::LayoutInfo elementSize = 0; + VulkanLayoutUtils::Size elementAlignment = 1; + + auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, + elementAlignment); + // According to the Vulkan spec: + // "An array has a base alignment equal to the base alignment of its element + // type." + size = elementSize * numElements; + alignment = elementAlignment; + return spirv::ArrayType::get(memberType, numElements, elementSize); +} + +VulkanLayoutUtils::Size +VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { + // According to the Vulkan spec: + // 1. "A scalar of size N has a scalar alignment of N." + // 2. "A scalar has a base alignment equal to its scalar alignment." + // 3. "A scalar, vector or matrix type has an extended alignment equal to its + // base alignment." + auto bitWidth = scalarType.getIntOrFloatBitWidth(); + if (bitWidth == 1) + return 1; + return bitWidth / 8; +} + +bool VulkanLayoutUtils::isLegalType(Type type) { + auto ptrType = type.dyn_cast<spirv::PointerType>(); + if (!ptrType) { + return true; + } + + auto storageClass = ptrType.getStorageClass(); + auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>(); + if (!structType) { + return true; + } + + switch (storageClass) { + case spirv::StorageClass::Uniform: + case spirv::StorageClass::StorageBuffer: + case spirv::StorageClass::PushConstant: + case spirv::StorageClass::PhysicalStorageBuffer: + return structType.hasLayout() || !structType.getNumElements(); + default: + return true; + } +} |