//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Standard Ops to the SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineMap.h" #include "llvm/ADT/SetVector.h" using namespace mlir; //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { /// Convert constant operation with IndexType return to SPIR-V constant /// operation. Since IndexType is not used within SPIR-V dialect, this needs /// special handling to make sure the result type and the type of the value /// attribute are consistent. // TODO(ravishankarm) : This should be moved into DRR. class ConstantIndexOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Convert floating-point comparison operations to SPIR-V dialect. class CmpFOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Convert compare operation to SPIR-V dialect. class CmpIOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Convert integer binary operations to SPIR-V operations. Cannot use /// tablegen for this. If the integer operation is on variables of IndexType, /// the type of the return value of the replacement operation differs from /// that of the replaced operation. This is not handled in tablegen-based /// pattern specification. // TODO(ravishankarm) : This should be moved into DRR. template class IntegerOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = this->typeConverter.convertType(operation.getResult().getType()); rewriter.template replaceOpWithNewOp( operation, resultType, operands, ArrayRef()); return this->matchSuccess(); } }; /// Convert load -> spv.LoadOp. The operands of the replaced operation are of /// IndexType while that of the replacement operation are of type i32. This is /// not supported in tablegen based pattern specification. // TODO(ravishankarm) : This should be moved into DRR. class LoadOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Convert return -> spv.Return. // TODO(ravishankarm) : This should be moved into DRR. class ReturnOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Convert select -> spv.Select // TODO(ravishankarm) : This should be moved into DRR. class SelectOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Convert store -> spv.StoreOp. The operands of the replaced operation are /// of IndexType while that of the replacement operation are of type i32. This /// is not supported in tablegen based pattern specification. // TODO(ravishankarm) : This should be moved into DRR. class StoreOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // Utility functions for operation conversion //===----------------------------------------------------------------------===// /// Performs the index computation to get to the element pointed to by /// `indices` using the layout map of `baseType`. // TODO(ravishankarm) : This method assumes that the `origBaseType` is a // MemRefType with AffineMap that has static strides. Handle dynamic strides static spirv::AccessChainOp getElementPtr(OpBuilder &builder, SPIRVTypeConverter &typeConverter, Location loc, MemRefType origBaseType, Value basePtr, ArrayRef indices) { // Get base and offset of the MemRefType and verify they are static. int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(origBaseType, strides, offset)) || llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { return nullptr; } auto indexType = typeConverter.getIndexType(builder.getContext()); Value ptrLoc = nullptr; assert(indices.size() == strides.size()); for (auto index : enumerate(indices)) { Value strideVal = builder.create( loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); Value update = builder.create(loc, strideVal, index.value()); ptrLoc = (ptrLoc ? builder.create(loc, ptrLoc, update).getResult() : update); } SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. linearizedIndices.push_back(builder.create( loc, indexType, IntegerAttr::get(indexType, 0))); linearizedIndices.push_back(ptrLoc); return builder.create(loc, basePtr, linearizedIndices); } //===----------------------------------------------------------------------===// // ConstantOp with index type. //===----------------------------------------------------------------------===// PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!constIndexOp.getResult().getType().isa()) { return matchFailure(); } // The attribute has index type which is not directly supported in // SPIR-V. Get the integer value and create a new IntegerAttr. auto constAttr = constIndexOp.value().dyn_cast(); if (!constAttr) { return matchFailure(); } // Use the bitwidth set in the value attribute to decide the result type // of the SPIR-V constant operation since SPIR-V does not support index // types. auto constVal = constAttr.getValue(); auto constValType = constAttr.getType().dyn_cast(); if (!constValType) { return matchFailure(); } auto spirvConstType = typeConverter.convertType(constIndexOp.getResult().getType()); auto spirvConstVal = rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); rewriter.replaceOpWithNewOp(constIndexOp, spirvConstType, spirvConstVal); return matchSuccess(); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// PatternMatchResult CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpFOpOperandAdaptor cmpFOpOperands(operands); switch (cmpFOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ cmpFOpOperands.lhs(), \ cmpFOpOperands.rhs()); \ return matchSuccess(); // Ordered. DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); // Unordered. DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); #undef DISPATCH default: break; } return matchFailure(); } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// PatternMatchResult CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ cmpIOpOperands.lhs(), \ cmpIOpOperands.rhs()); \ return matchSuccess(); DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH } return matchFailure(); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// PatternMatchResult LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), loadOp.memref().getType().cast(), loadOperands.memref(), loadOperands.indices()); rewriter.replaceOpWithNewOp(loadOp, loadPtr, /*memory_access =*/nullptr, /*alignment =*/nullptr); return matchSuccess(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// PatternMatchResult ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands()) { return matchFailure(); } rewriter.replaceOpWithNewOp(returnOp); return matchSuccess(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// PatternMatchResult SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { SelectOpOperandAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), selectOperands.true_value(), selectOperands.false_value()); return matchSuccess(); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// PatternMatchResult StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpOperandAdaptor storeOperands(operands); auto storePtr = getElementPtr(rewriter, typeConverter, storeOp.getLoc(), storeOp.memref().getType().cast(), storeOperands.memref(), storeOperands.indices()); rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value(), /*memory_access =*/nullptr, /*alignment =*/nullptr); return matchSuccess(); } namespace { /// Import the Standard Ops to SPIR-V Patterns. #include "StandardToSPIRV.cpp.inc" } // namespace namespace mlir { void populateStandardToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect. populateWithGenerated(context, &patterns); patterns.insert, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, LoadOpConversion, ReturnOpConversion, SelectOpConversion, StoreOpConversion>( context, typeConverter); } } // namespace mlir