//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This transformation pass legalizes operations before the conversion to SPIR-V // dialect to handle ops that cannot be lowered directly. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" using namespace mlir; namespace { /// Merges subview operation with load operation. class LoadOpOfSubViewFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(LoadOp loadOp, PatternRewriter &rewriter) const override; }; /// Merges subview operation with store operation. class StoreOpOfSubViewFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(StoreOp storeOp, PatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // Utility functions for op legalization. //===----------------------------------------------------------------------===// /// Given the 'indices' of an load/store operation where the memref is a result /// of a subview op, returns the indices w.r.t to the source memref of the /// subview op. For example /// /// %0 = ... : memref<12x42xf32> /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to /// memref<4x4xf32, offset=?, strides=[?, ?]> /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> /// /// could be folded into /// /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : /// memref<12x42xf32> static LogicalResult resolveSourceIndices(Location loc, PatternRewriter &rewriter, SubViewOp subViewOp, ValueRange indices, SmallVectorImpl &sourceIndices) { // TODO: Aborting when the offsets are static. There might be a way to fold // the subview op with load even if the offsets have been canonicalized // away. if (subViewOp.getNumOffsets() == 0) return failure(); ValueRange opOffsets = subViewOp.offsets(); SmallVector opStrides; if (subViewOp.getNumStrides()) { // If the strides are dynamic, get the stride operands. opStrides = llvm::to_vector<2>(subViewOp.strides()); } else { // When static, the stride operands can be retrieved by taking the strides // of the result of the subview op, and dividing the strides of the base // memref. SmallVector staticStrides; if (failed(subViewOp.getStaticStrides(staticStrides))) { return failure(); } opStrides.reserve(opOffsets.size()); for (auto stride : staticStrides) { auto constValAttr = rewriter.getIntegerAttr( IndexType::get(rewriter.getContext()), stride); opStrides.emplace_back(rewriter.create(loc, constValAttr)); } } assert(opOffsets.size() == opStrides.size()); // New indices for the load are the current indices * subview_stride + // subview_offset. assert(indices.size() == opStrides.size()); sourceIndices.resize(indices.size()); for (auto index : llvm::enumerate(indices)) { auto offset = opOffsets[index.index()]; auto stride = opStrides[index.index()]; auto mul = rewriter.create(loc, index.value(), stride); sourceIndices[index.index()] = rewriter.create(loc, offset, mul).getResult(); } return success(); } //===----------------------------------------------------------------------===// // Folding SubViewOp and LoadOp. //===----------------------------------------------------------------------===// PatternMatchResult LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, PatternRewriter &rewriter) const { auto subViewOp = dyn_cast_or_null(loadOp.memref().getDefiningOp()); if (!subViewOp) { return matchFailure(); } SmallVector sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices))) return matchFailure(); rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), sourceIndices); return matchSuccess(); } //===----------------------------------------------------------------------===// // Folding SubViewOp and StoreOp. //===----------------------------------------------------------------------===// PatternMatchResult StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, PatternRewriter &rewriter) const { auto subViewOp = dyn_cast_or_null(storeOp.memref().getDefiningOp()); if (!subViewOp) { return matchFailure(); } SmallVector sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices))) return matchFailure(); rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), subViewOp.source(), sourceIndices); return matchSuccess(); } //===----------------------------------------------------------------------===// // Hook for adding patterns. //===----------------------------------------------------------------------===// void mlir::populateStdLegalizationPatternsForSPIRVLowering( MLIRContext *context, OwningRewritePatternList &patterns) { patterns.insert(context); } //===----------------------------------------------------------------------===// // Pass for testing just the legalization patterns. //===----------------------------------------------------------------------===// namespace { struct SPIRVLegalization final : public OperationPass { void runOnOperation() override; }; } // namespace void SPIRVLegalization::runOnOperation() { OwningRewritePatternList patterns; auto *context = &getContext(); populateStdLegalizationPatternsForSPIRVLowering(context, patterns); applyPatternsGreedily(getOperation()->getRegions(), patterns); } std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { return std::make_unique(); } static PassRegistration pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering");