//===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to lower attributes that specify the shader ABI // for the functions in the generated SPIR-V module. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/Passes.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SetVector.h" using namespace mlir; /// Checks if the `type` is a scalar or vector type. It is assumed that they are /// valid for SPIR-V dialect already. static bool isScalarOrVectorType(Type type) { return spirv::SPIRVDialect::isValidScalarType(type) || type.isa(); } /// Creates a global variable for an argument based on the ABI info. static spirv::GlobalVariableOp createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum, spirv::InterfaceVarABIAttr abiInfo) { auto spirvModule = funcOp.getParentOfType(); if (!spirvModule) { return nullptr; } OpBuilder::InsertionGuard moduleInsertionGuard(builder); builder.setInsertionPoint(funcOp.getOperation()); std::string varName = funcOp.getName().str() + "_arg_" + std::to_string(argNum); // Get the type of variable. If this is a scalar/vector type and has an ABI // info create a variable of type !spv.ptr>. If not // it must already be a !spv.ptr>. auto varType = funcOp.getType().getInput(argNum); auto storageClass = static_cast(abiInfo.storage_class().getInt()); if (isScalarOrVectorType(varType)) { varType = spirv::PointerType::get(spirv::StructType::get(varType), storageClass); } auto varPtrType = varType.cast(); auto varPointeeType = varPtrType.getPointeeType().cast(); // Set the offset information. VulkanLayoutUtils::Size size = 0, alignment = 0; varPointeeType = VulkanLayoutUtils::decorateType(varPointeeType, size, alignment) .cast(); varType = spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); return builder.create( funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(), abiInfo.binding().getInt()); } /// Gets the global variables that need to be specified as interface variable /// with an spv.EntryPointOp. Traverses the body of a entry function to do so. static LogicalResult getInterfaceVariables(FuncOp funcOp, SmallVectorImpl &interfaceVars) { auto module = funcOp.getParentOfType(); if (!module) { return failure(); } llvm::SetVector interfaceVarSet; // TODO(ravishankarm) : This should in reality traverse the entry function // call graph and collect all the interfaces. For now, just traverse the // instructions in this function. funcOp.walk([&](spirv::AddressOfOp addressOfOp) { auto var = module.lookupSymbol(addressOfOp.variable()); if (var.type().cast().getStorageClass() != spirv::StorageClass::StorageBuffer) { interfaceVarSet.insert(var.getOperation()); } }); for (auto &var : interfaceVarSet) { interfaceVars.push_back(SymbolRefAttr::get( cast(var).sym_name(), funcOp.getContext())); } return success(); } /// Lowers the entry point attribute. static LogicalResult lowerEntryPointABIAttr(FuncOp funcOp, OpBuilder &builder) { auto entryPointAttrName = spirv::getEntryPointABIAttrName(); auto entryPointAttr = funcOp.getAttrOfType(entryPointAttrName); if (!entryPointAttr) { return failure(); } OpBuilder::InsertionGuard moduleInsertionGuard(builder); auto spirvModule = funcOp.getParentOfType(); builder.setInsertionPoint(spirvModule.body().front().getTerminator()); // Adds the spv.EntryPointOp after collecting all the interface variables // needed. SmallVector interfaceVars; if (failed(getInterfaceVariables(funcOp, interfaceVars))) { return failure(); } builder.create( funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars); // Specifies the spv.ExecutionModeOp. auto localSizeAttr = entryPointAttr.local_size(); SmallVector localSize(localSizeAttr.getValues()); builder.create( funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize); funcOp.removeAttr(entryPointAttrName); return success(); } namespace { /// Pattern rewriter for changing function signature to match the ABI specified /// in attributes. class FuncOpLowering final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Pass to implement the ABI information specified as attributes. class LowerABIAttributesPass final : public OperationPass { private: void runOnOperation() override; }; } // namespace PatternMatchResult FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!funcOp.getAttrOfType( spirv::getEntryPointABIAttrName())) { // TODO(ravishankarm) : Non-entry point functions are not handled. return matchFailure(); } TypeConverter::SignatureConversion signatureConverter( funcOp.getType().getNumInputs()); auto attrName = spirv::getInterfaceVarABIAttrName(); for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) { auto abiInfo = funcOp.getArgAttrOfType( argType.index(), attrName); if (!abiInfo) { // TODO(ravishankarm) : For non-entry point functions, it should be legal // to pass around scalar/vector values and return a scalar/vector. For now // non-entry point functions are not handled in this ABI lowering and will // produce an error. return matchFailure(); } auto var = createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo); if (!var) { return matchFailure(); } OpBuilder::InsertionGuard funcInsertionGuard(rewriter); rewriter.setInsertionPointToStart(&funcOp.front()); // Insert spirv::AddressOf and spirv::AccessChain operations. Value replacement = rewriter.create(funcOp.getLoc(), var); // Check if the arg is a scalar or vector type. In that case, the value // needs to be loaded into registers. // TODO(ravishankarm) : This is loading value of the scalar into registers // at the start of the function. It is probably better to do the load just // before the use. There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations. if (isScalarOrVectorType(argType.value())) { auto indexType = typeConverter.convertType(IndexType::get(funcOp.getContext())); auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter); auto loadPtr = rewriter.create( funcOp.getLoc(), replacement, zero.constant()); replacement = rewriter.create(funcOp.getLoc(), loadPtr, /*memory_access=*/nullptr, /*alignment=*/nullptr); } signatureConverter.remapInput(argType.index(), replacement); } // Creates a new function with the update signature. rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType( signatureConverter.getConvertedTypes(), llvm::None)); rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); }); return matchSuccess(); } void LowerABIAttributesPass::runOnOperation() { // Uses the signature conversion methodology of the dialect conversion // framework to implement the conversion. spirv::ModuleOp module = getOperation(); MLIRContext *context = &getContext(); SPIRVTypeConverter typeConverter; OwningRewritePatternList patterns; patterns.insert(context, typeConverter); std::unique_ptr target = spirv::SPIRVConversionTarget::get( spirv::lookupTargetEnvOrDefault(module), context); auto entryPointAttrName = spirv::getEntryPointABIAttrName(); target->addDynamicallyLegalOp([&](FuncOp op) { return op.getAttrOfType(entryPointAttrName) && op.getNumResults() == 0 && op.getNumArguments() == 0; }); target->addLegalOp(); if (failed( applyPartialConversion(module, *target, patterns, &typeConverter))) { return signalPassFailure(); } // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point // attributes. OpBuilder builder(context); SmallVector entryPointFns; module.walk([&](FuncOp funcOp) { if (funcOp.getAttrOfType(entryPointAttrName)) { entryPointFns.push_back(funcOp); } }); for (auto fn : entryPointFns) { if (failed(lowerEntryPointABIAttr(fn, builder))) { return signalPassFailure(); } } } std::unique_ptr> mlir::spirv::createLowerABIAttributesPass() { return std::make_unique(); } static PassRegistration pass("spirv-lower-abi-attrs", "Lower SPIR-V ABI Attributes");