diff options
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV')
5 files changed, 645 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt new file mode 100644 index 00000000000..fcced23a95e --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -0,0 +1,26 @@ +set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td) +mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRStandardToSPIRVIncGen) + +add_llvm_library(MLIRStandardToSPIRVTransforms + ConvertStandardToSPIRV.cpp + ConvertStandardToSPIRVPass.cpp + LegalizeStandardForSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + ) + +add_dependencies(MLIRStandardToSPIRVTransforms + MLIRStandardToSPIRVIncGen) + +target_link_libraries(MLIRStandardToSPIRVTransforms + MLIRIR + MLIRPass + MLIRSPIRV + MLIRSupport + MLIRTransformUtils + MLIRSPIRV + MLIRStandardOps + ) diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp new file mode 100644 index 00000000000..a02dee4419a --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -0,0 +1,314 @@ +//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert Standard Ops to the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/LayoutUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineMap.h" +#include "llvm/ADT/SetVector.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation conversion +//===----------------------------------------------------------------------===// + +namespace { + +/// Convert constant operation with IndexType return to SPIR-V constant +/// operation. Since IndexType is not used within SPIR-V dialect, this needs +/// special handling to make sure the result type and the type of the value +/// attribute are consistent. +// TODO(ravishankarm) : This should be moved into DRR. +class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> { +public: + using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert compare operation to SPIR-V dialect. +class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> { +public: + using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert integer binary operations to SPIR-V operations. Cannot use +/// tablegen for this. If the integer operation is on variables of IndexType, +/// the type of the return value of the replacement operation differs from +/// that of the replaced operation. This is not handled in tablegen-based +/// pattern specification. +// TODO(ravishankarm) : This should be moved into DRR. +template <typename StdOp, typename SPIRVOp> +class IntegerOpConversion final : public SPIRVOpLowering<StdOp> { +public: + using SPIRVOpLowering<StdOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(StdOp operation, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto resultType = + this->typeConverter.convertType(operation.getResult()->getType()); + rewriter.template replaceOpWithNewOp<SPIRVOp>( + operation, resultType, operands, ArrayRef<NamedAttribute>()); + return this->matchSuccess(); + } +}; + +/// Convert load -> spv.LoadOp. The operands of the replaced operation are of +/// IndexType while that of the replacement operation are of type i32. This is +/// not supported in tablegen based pattern specification. +// TODO(ravishankarm) : This should be moved into DRR. +class LoadOpConversion final : public SPIRVOpLowering<LoadOp> { +public: + using SPIRVOpLowering<LoadOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert return -> spv.Return. +// TODO(ravishankarm) : This should be moved into DRR. +class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> { +public: + using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert select -> spv.Select +// TODO(ravishankarm) : This should be moved into DRR. +class SelectOpConversion final : public SPIRVOpLowering<SelectOp> { +public: + using SPIRVOpLowering<SelectOp>::SPIRVOpLowering; + PatternMatchResult + matchAndRewrite(SelectOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert store -> spv.StoreOp. The operands of the replaced operation are +/// of IndexType while that of the replacement operation are of type i32. This +/// is not supported in tablegen based pattern specification. +// TODO(ravishankarm) : This should be moved into DRR. +class StoreOpConversion final : public SPIRVOpLowering<StoreOp> { +public: + using SPIRVOpLowering<StoreOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Utility functions for operation conversion +//===----------------------------------------------------------------------===// + +/// Performs the index computation to get to the element pointed to by +/// `indices` using the layout map of `baseType`. + +// TODO(ravishankarm) : This method assumes that the `origBaseType` is a +// MemRefType with AffineMap that has static strides. Handle dynamic strides +spirv::AccessChainOp getElementPtr(OpBuilder &builder, + SPIRVTypeConverter &typeConverter, + Location loc, MemRefType origBaseType, + Value basePtr, ArrayRef<Value> indices) { + // Get base and offset of the MemRefType and verify they are static. + int64_t offset; + SmallVector<int64_t, 4> strides; + if (failed(getStridesAndOffset(origBaseType, strides, offset)) || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { + return nullptr; + } + + auto indexType = typeConverter.getIndexType(builder.getContext()); + + Value ptrLoc = nullptr; + assert(indices.size() == strides.size()); + for (auto index : enumerate(indices)) { + Value strideVal = builder.create<spirv::ConstantOp>( + loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); + Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); + ptrLoc = + (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult() + : update); + } + SmallVector<Value, 2> linearizedIndices; + // Add a '0' at the start to index into the struct. + linearizedIndices.push_back(builder.create<spirv::ConstantOp>( + loc, indexType, IntegerAttr::get(indexType, 0))); + linearizedIndices.push_back(ptrLoc); + return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); +} + +//===----------------------------------------------------------------------===// +// ConstantOp with index type. +//===----------------------------------------------------------------------===// + +PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( + ConstantOp constIndexOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (!constIndexOp.getResult()->getType().isa<IndexType>()) { + return matchFailure(); + } + // The attribute has index type which is not directly supported in + // SPIR-V. Get the integer value and create a new IntegerAttr. + auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>(); + if (!constAttr) { + return matchFailure(); + } + + // Use the bitwidth set in the value attribute to decide the result type + // of the SPIR-V constant operation since SPIR-V does not support index + // types. + auto constVal = constAttr.getValue(); + auto constValType = constAttr.getType().dyn_cast<IndexType>(); + if (!constValType) { + return matchFailure(); + } + auto spirvConstType = + typeConverter.convertType(constIndexOp.getResult()->getType()); + auto spirvConstVal = + rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); + rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType, + spirvConstVal); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + CmpIOpOperandAdaptor cmpIOpOperands(operands); + + switch (cmpIOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp<spirvOp>( \ + cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \ + cmpIOpOperands.rhs()); \ + return matchSuccess(); + + DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); + DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); + DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); + DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); + DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); + DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + +#undef DISPATCH + + default: + break; + } + return matchFailure(); +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + LoadOpOperandAdaptor loadOperands(operands); + auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), + loadOp.memref()->getType().cast<MemRefType>(), + loadOperands.memref(), loadOperands.indices()); + rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, + /*memory_access =*/nullptr, + /*alignment =*/nullptr); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (returnOp.getNumOperands()) { + return matchFailure(); + } + rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + SelectOpOperandAdaptor selectOperands(operands); + rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(), + selectOperands.true_value(), + selectOperands.false_value()); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + StoreOpOperandAdaptor storeOperands(operands); + auto storePtr = + getElementPtr(rewriter, typeConverter, storeOp.getLoc(), + storeOp.memref()->getType().cast<MemRefType>(), + storeOperands.memref(), storeOperands.indices()); + rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, + storeOperands.value(), + /*memory_access =*/nullptr, + /*alignment =*/nullptr); + return matchSuccess(); +} + +namespace { +/// Import the Standard Ops to SPIR-V Patterns. +#include "StandardToSPIRV.cpp.inc" +} // namespace + +namespace mlir { +void populateStandardToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns) { + // Add patterns that lower operations into SPIR-V dialect. + populateWithGenerated(context, &patterns); + patterns.insert<ConstantIndexOpConversion, CmpIOpConversion, + IntegerOpConversion<AddIOp, spirv::IAddOp>, + IntegerOpConversion<MulIOp, spirv::IMulOp>, + IntegerOpConversion<SignedDivIOp, spirv::SDivOp>, + IntegerOpConversion<SignedRemIOp, spirv::SModOp>, + IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion, + ReturnOpConversion, SelectOpConversion, StoreOpConversion>( + context, typeConverter); +} +} // namespace mlir diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp new file mode 100644 index 00000000000..52456b6e46d --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -0,0 +1,89 @@ +//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR standard ops into the SPIR-V +// ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +/// A simple pattern for rewriting function signature to convert arguments of +/// functions to be of valid SPIR-V types. +class FuncOpConversion final : public SPIRVOpLowering<FuncOp> { +public: + using SPIRVOpLowering<FuncOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// A pass converting MLIR Standard operations into the SPIR-V dialect. +class ConvertStandardToSPIRVPass + : public ModulePass<ConvertStandardToSPIRVPass> { + void runOnModule() override; +}; +} // namespace + +PatternMatchResult +FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + auto fnType = funcOp.getType(); + if (fnType.getNumResults()) { + return matchFailure(); + } + + TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); + { + for (auto argType : enumerate(funcOp.getType().getInputs())) { + auto convertedType = typeConverter.convertType(argType.value()); + signatureConverter.addInputs(argType.index(), convertedType); + } + } + + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), llvm::None)); + rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); + }); + return matchSuccess(); +} + +void ConvertStandardToSPIRVPass::runOnModule() { + OwningRewritePatternList patterns; + auto context = &getContext(); + auto module = getModule(); + + SPIRVTypeConverter typeConverter; + populateStandardToSPIRVPatterns(context, typeConverter, patterns); + patterns.insert<FuncOpConversion>(context, typeConverter); + ConversionTarget target(*(module.getContext())); + target.addLegalDialect<spirv::SPIRVDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); + + if (failed(applyPartialConversion(module, target, patterns))) { + return signalPassFailure(); + } +} + +std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() { + return std::make_unique<ConvertStandardToSPIRVPass>(); +} + +static PassRegistration<ConvertStandardToSPIRVPass> + pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect"); diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp new file mode 100644 index 00000000000..a658356f76c --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -0,0 +1,181 @@ +//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass legalizes operations before the conversion to SPIR-V +// dialect to handle ops that cannot be lowered directly. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// Merges subview operation with load operation. +class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> { +public: + using OpRewritePattern<LoadOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges subview operation with store operation. +class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> { +public: + using OpRewritePattern<StoreOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Utility functions for op legalization. +//===----------------------------------------------------------------------===// + +/// Given the 'indices' of an load/store operation where the memref is a result +/// of a subview op, returns the indices w.r.t to the source memref of the +/// subview op. For example +/// +/// %0 = ... : memref<12x42xf32> +/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to +/// memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// +/// could be folded into +/// +/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +static LogicalResult +resolveSourceIndices(Location loc, PatternRewriter &rewriter, + SubViewOp subViewOp, ValueRange indices, + SmallVectorImpl<Value> &sourceIndices) { + // TODO: Aborting when the offsets are static. There might be a way to fold + // the subview op with load even if the offsets have been canonicalized + // away. + if (subViewOp.getNumOffsets() == 0) + return failure(); + + ValueRange opOffsets = subViewOp.offsets(); + SmallVector<Value, 2> opStrides; + if (subViewOp.getNumStrides()) { + // If the strides are dynamic, get the stride operands. + opStrides = llvm::to_vector<2>(subViewOp.strides()); + } else { + // When static, the stride operands can be retrieved by taking the strides + // of the result of the subview op, and dividing the strides of the base + // memref. + SmallVector<int64_t, 2> staticStrides; + if (failed(subViewOp.getStaticStrides(staticStrides))) { + return failure(); + } + opStrides.reserve(opOffsets.size()); + for (auto stride : staticStrides) { + auto constValAttr = rewriter.getIntegerAttr( + IndexType::get(rewriter.getContext()), stride); + opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr)); + } + } + assert(opOffsets.size() == opStrides.size()); + + // New indices for the load are the current indices * subview_stride + + // subview_offset. + assert(indices.size() == opStrides.size()); + sourceIndices.resize(indices.size()); + for (auto index : llvm::enumerate(indices)) { + auto offset = opOffsets[index.index()]; + auto stride = opStrides[index.index()]; + auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); + sourceIndices[index.index()] = + rewriter.create<AddIOp>(loc, offset, mul).getResult(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and LoadOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector<Value, 4> sourceIndices; + if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, + loadOp.indices(), sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(), + sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and StoreOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector<Value, 4> sourceIndices; + if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, + storeOp.indices(), sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(), + subViewOp.source(), sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Hook for adding patterns. +//===----------------------------------------------------------------------===// + +void mlir::populateStdLegalizationPatternsForSPIRVLowering( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context); +} + +//===----------------------------------------------------------------------===// +// Pass for testing just the legalization patterns. +//===----------------------------------------------------------------------===// + +namespace { +struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> { + void runOnOperation() override; +}; +} // namespace + +void SPIRVLegalization::runOnOperation() { + OwningRewritePatternList patterns; + auto *context = &getContext(); + populateStdLegalizationPatternsForSPIRVLowering(context, patterns); + applyPatternsGreedily(getOperation()->getRegions(), patterns); +} + +std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() { + return std::make_unique<SPIRVLegalization>(); +} + +static PassRegistration<SPIRVLegalization> + pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering"); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td new file mode 100644 index 00000000000..6f3a6a82476 --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td @@ -0,0 +1,35 @@ +//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==// + +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines Patterns to lower standard ops to SPIR-V. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_TD +#define MLIR_CONVERSION_STANDARDTOSPIRV_TD + +include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/SPIRV/SPIRVOps.td" + +class BinaryOpPattern<Op src, Op tgt> : + Pat<(src SPV_ScalarOrVector:$l, SPV_ScalarOrVector:$r), + (tgt $l, $r)>; + +def : BinaryOpPattern<AddFOp, SPV_FAddOp>; +def : BinaryOpPattern<DivFOp, SPV_FDivOp>; +def : BinaryOpPattern<MulFOp, SPV_FMulOp>; +def : BinaryOpPattern<RemFOp, SPV_FRemOp>; +def : BinaryOpPattern<SubFOp, SPV_FSubOp>; + +// Constant Op +// TODO(ravishankarm): Handle lowering other constant types. +def : Pat<(ConstantOp:$result $valueAttr), + (SPV_ConstantOp $valueAttr), + [(SPV_ScalarOrVector $result)]>; + +#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD |