//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements utilities used to lower to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "mlir-spirv-lowering" using namespace mlir; //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { // Convert to 32-bit integers for now. Might need a way to control this in // future. // TODO(ravishankarm): It is probably better to make it 64-bit integers. To // this some support is needed in SPIR-V dialect for Conversion // instructions. The Vulkan spec requires the builtins like // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be // SExtended to 64-bit for index computations. return IntegerType::get(32, context); } // TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. static Optional getTypeNumBytes(Type t) { if (auto integerType = t.dyn_cast()) { return integerType.getWidth() / 8; } else if (auto floatType = t.dyn_cast()) { return floatType.getWidth() / 8; } else if (auto memRefType = t.dyn_cast()) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. int64_t offset; SmallVector strides; if (!memRefType.hasStaticShape() || failed(getStridesAndOffset(memRefType, strides, offset))) { return llvm::None; } // To get the size of the memref object in memory, the total size is the // max(stride * dimension-size) computed for all dimensions times the size // of the element. auto elementSize = getTypeNumBytes(memRefType.getElementType()); if (!elementSize) { return llvm::None; } auto dims = memRefType.getShape(); if (llvm::is_contained(dims, ShapedType::kDynamicSize) || offset == MemRefType::getDynamicStrideOrOffset() || llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { return llvm::None; } int64_t memrefSize = -1; for (auto shape : enumerate(dims)) { memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); } return (offset + memrefSize) * elementSize.getValue(); } // TODO: Add size computation for other types. return llvm::None; } static Type convertStdType(Type type) { // If the type is already valid in SPIR-V, directly return. if (spirv::SPIRVDialect::isValidType(type)) { return type; } if (auto indexType = type.dyn_cast()) { return SPIRVTypeConverter::getIndexType(type.getContext()); } if (auto memRefType = type.dyn_cast()) { // TODO(ravishankarm): For now only support default memory space. The memory // space description is not set is stone within MLIR, i.e. it depends on the // context it is being used. To map this to SPIR-V storage classes, we // should rely on the ABI attributes, and not on the memory space. This is // still evolving, and needs to be revisited when there is more clarity. if (memRefType.getMemorySpace()) { return Type(); } auto elementType = convertStdType(memRefType.getElementType()); if (!elementType) { return Type(); } auto elementSize = getTypeNumBytes(elementType); if (!elementSize) { return Type(); } // TODO(ravishankarm) : Handle dynamic shapes. if (memRefType.hasStaticShape()) { auto arraySize = getTypeNumBytes(memRefType); if (!arraySize) { return Type(); } auto arrayType = spirv::ArrayType::get( elementType, arraySize.getValue() / elementSize.getValue(), elementSize.getValue()); auto structType = spirv::StructType::get(arrayType, 0); // For now initialize the storage class to StorageBuffer. This will be // updated later based on whats passed in w.r.t to the ABI attributes. return spirv::PointerType::get(structType, spirv::StorageClass::StorageBuffer); } } return Type(); } Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); } //===----------------------------------------------------------------------===// // Builtin Variables //===----------------------------------------------------------------------===// /// Look through all global variables in `moduleOp` and check if there is a /// spv.globalVariable that has the same `builtin` attribute. static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp, spirv::BuiltIn builtin) { for (auto varOp : moduleOp.getBlock().getOps()) { if (auto builtinAttr = varOp.getAttrOfType( spirv::SPIRVDialect::getAttributeName( spirv::Decoration::BuiltIn))) { auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); if (varBuiltIn && varBuiltIn.getValue() == builtin) { return varOp; } } } return nullptr; } /// Gets name of global variable for a builtin. static std::string getBuiltinVarName(spirv::BuiltIn builtin) { return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; } /// Gets or inserts a global variable for a builtin within a module. static spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc, spirv::BuiltIn builtin, OpBuilder &builder) { if (auto varOp = getBuiltinVariable(moduleOp, builtin)) { return varOp; } auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToStart(&moduleOp.getBlock()); auto name = getBuiltinVarName(builtin); spirv::GlobalVariableOp newVarOp; switch (builtin) { case spirv::BuiltIn::NumWorkgroups: case spirv::BuiltIn::WorkgroupSize: case spirv::BuiltIn::WorkgroupId: case spirv::BuiltIn::LocalInvocationId: case spirv::BuiltIn::GlobalInvocationId: { auto ptrType = spirv::PointerType::get( VectorType::get({3}, builder.getIntegerType(32)), spirv::StorageClass::Input); newVarOp = builder.create(loc, ptrType, name, builtin); break; } default: emitError(loc, "unimplemented builtin variable generation for ") << stringifyBuiltIn(builtin); } builder.restoreInsertionPoint(ip); return newVarOp; } /// Gets the global variable associated with a builtin and add /// it if it doesn't exist. Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, OpBuilder &builder) { auto moduleOp = op->getParentOfType(); if (!moduleOp) { op->emitError("expected operation to be within a SPIR-V module"); return nullptr; } spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder); Value ptr = builder.create(op->getLoc(), varOp); return builder.create(op->getLoc(), ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr); } //===----------------------------------------------------------------------===// // Set ABI attributes for lowering entry functions. //===----------------------------------------------------------------------===// LogicalResult mlir::spirv::setABIAttrs(FuncOp funcOp, spirv::EntryPointABIAttr entryPointInfo, ArrayRef argABIInfo) { // Set the attributes for argument and the function. StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); for (auto argIndex : llvm::seq(0, funcOp.getNumArguments())) { funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]); } funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); return success(); } //===----------------------------------------------------------------------===// // SPIR-V ConversionTarget //===----------------------------------------------------------------------===// std::unique_ptr spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv, MLIRContext *context) { std::unique_ptr target( // std::make_unique does not work here because the constructor is private. new SPIRVConversionTarget(targetEnv, context)); SPIRVConversionTarget *targetPtr = target.get(); target->addDynamicallyLegalDialect( Optional( // We need to capture the raw pointer here because it is stable: // target will be destroyed once this function is returned. [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); })); return target; } spirv::SPIRVConversionTarget::SPIRVConversionTarget( spirv::TargetEnvAttr targetEnv, MLIRContext *context) : ConversionTarget(*context), givenVersion(static_cast(targetEnv.version().getInt())) { for (Attribute extAttr : targetEnv.extensions()) givenExtensions.insert( *spirv::symbolizeExtension(extAttr.cast().getValue())); for (Attribute capAttr : targetEnv.capabilities()) givenCapabilities.insert( static_cast(capAttr.cast().getInt())); } bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { // Make sure this op is available at the given version. Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all // SPIR-V versions. if (auto minVersion = dyn_cast(op)) if (minVersion.getMinVersion() > givenVersion) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring min version " << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"); return false; } if (auto maxVersion = dyn_cast(op)) if (maxVersion.getMaxVersion() < givenVersion) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring max version " << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"); return false; } // Make sure this op's required extensions are allowed to use. For each op, // we return a vector of vector for its extension requirements following // ((Extension::A OR Extenion::B) AND (Extension::C OR Extension::D)) // convention. Ops not implementing QueryExtensionInterface do not require // extensions to be available. if (auto extensions = dyn_cast(op)) { auto exts = extensions.getExtensions(); for (const auto &ors : exts) if (llvm::all_of(ors, [this](spirv::Extension ext) { return this->givenExtensions.count(ext) == 0; })) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: missing required extension\n"); return false; } } // Make sure this op's required extensions are allowed to use. For each op, // we return a vector of vector for its capability requirements following // ((Capability::A OR Extenion::B) AND (Capability::C OR Capability::D)) // convention. Ops not implementing QueryExtensionInterface do not require // extensions to be available. if (auto capabilities = dyn_cast(op)) { auto caps = capabilities.getCapabilities(); for (const auto &ors : caps) if (llvm::all_of(ors, [this](spirv::Capability cap) { return this->givenCapabilities.count(cap) == 0; })) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: missing required capability\n"); return false; } } return true; };