//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard ops into the SPIR-V // ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Pass/Pass.h" using namespace mlir; namespace { /// A simple pattern for rewriting function signature to convert arguments of /// functions to be of valid SPIR-V types. class FuncOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// A pass converting MLIR Standard operations into the SPIR-V dialect. class ConvertStandardToSPIRVPass : public ModulePass { void runOnModule() override; }; } // namespace PatternMatchResult FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); if (fnType.getNumResults()) { return matchFailure(); } TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); { for (auto argType : enumerate(funcOp.getType().getInputs())) { auto convertedType = typeConverter.convertType(argType.value()); signatureConverter.addInputs(argType.index(), convertedType); } } rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType( signatureConverter.getConvertedTypes(), llvm::None)); rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); }); return matchSuccess(); } void ConvertStandardToSPIRVPass::runOnModule() { MLIRContext *context = &getContext(); ModuleOp module = getModule(); SPIRVTypeConverter typeConverter; OwningRewritePatternList patterns; populateStandardToSPIRVPatterns(context, typeConverter, patterns); patterns.insert(context, typeConverter); std::unique_ptr target = spirv::SPIRVConversionTarget::get( spirv::lookupTargetEnvOrDefault(module), context); target->addDynamicallyLegalOp( [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); if (failed(applyPartialConversion(module, *target, patterns))) { return signalPassFailure(); } } std::unique_ptr> mlir::createConvertStandardToSPIRVPass() { return std::make_unique(); } static PassRegistration pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect");