diff options
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp new file mode 100644 index 00000000000..a658356f76c --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -0,0 +1,181 @@ +//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass legalizes operations before the conversion to SPIR-V +// dialect to handle ops that cannot be lowered directly. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// Merges subview operation with load operation. +class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> { +public: + using OpRewritePattern<LoadOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges subview operation with store operation. +class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> { +public: + using OpRewritePattern<StoreOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Utility functions for op legalization. +//===----------------------------------------------------------------------===// + +/// Given the 'indices' of an load/store operation where the memref is a result +/// of a subview op, returns the indices w.r.t to the source memref of the +/// subview op. For example +/// +/// %0 = ... : memref<12x42xf32> +/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to +/// memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// +/// could be folded into +/// +/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +static LogicalResult +resolveSourceIndices(Location loc, PatternRewriter &rewriter, + SubViewOp subViewOp, ValueRange indices, + SmallVectorImpl<Value> &sourceIndices) { + // TODO: Aborting when the offsets are static. There might be a way to fold + // the subview op with load even if the offsets have been canonicalized + // away. + if (subViewOp.getNumOffsets() == 0) + return failure(); + + ValueRange opOffsets = subViewOp.offsets(); + SmallVector<Value, 2> opStrides; + if (subViewOp.getNumStrides()) { + // If the strides are dynamic, get the stride operands. + opStrides = llvm::to_vector<2>(subViewOp.strides()); + } else { + // When static, the stride operands can be retrieved by taking the strides + // of the result of the subview op, and dividing the strides of the base + // memref. + SmallVector<int64_t, 2> staticStrides; + if (failed(subViewOp.getStaticStrides(staticStrides))) { + return failure(); + } + opStrides.reserve(opOffsets.size()); + for (auto stride : staticStrides) { + auto constValAttr = rewriter.getIntegerAttr( + IndexType::get(rewriter.getContext()), stride); + opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr)); + } + } + assert(opOffsets.size() == opStrides.size()); + + // New indices for the load are the current indices * subview_stride + + // subview_offset. + assert(indices.size() == opStrides.size()); + sourceIndices.resize(indices.size()); + for (auto index : llvm::enumerate(indices)) { + auto offset = opOffsets[index.index()]; + auto stride = opStrides[index.index()]; + auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); + sourceIndices[index.index()] = + rewriter.create<AddIOp>(loc, offset, mul).getResult(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and LoadOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector<Value, 4> sourceIndices; + if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, + loadOp.indices(), sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(), + sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and StoreOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector<Value, 4> sourceIndices; + if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, + storeOp.indices(), sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(), + subViewOp.source(), sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Hook for adding patterns. +//===----------------------------------------------------------------------===// + +void mlir::populateStdLegalizationPatternsForSPIRVLowering( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context); +} + +//===----------------------------------------------------------------------===// +// Pass for testing just the legalization patterns. +//===----------------------------------------------------------------------===// + +namespace { +struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> { + void runOnOperation() override; +}; +} // namespace + +void SPIRVLegalization::runOnOperation() { + OwningRewritePatternList patterns; + auto *context = &getContext(); + populateStdLegalizationPatternsForSPIRVLowering(context, patterns); + applyPatternsGreedily(getOperation()->getRegions(), patterns); +} + +std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() { + return std::make_unique<SPIRVLegalization>(); +} + +static PassRegistration<SPIRVLegalization> + pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering"); |