summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp181
1 files changed, 181 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
new file mode 100644
index 00000000000..a658356f76c
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -0,0 +1,181 @@
+//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes operations before the conversion to SPIR-V
+// dialect to handle ops that cannot be lowered directly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Merges subview operation with load operation.
+class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
+public:
+ using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges subview operation with store operation.
+class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
+public:
+ using OpRewritePattern<StoreOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for op legalization.
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of an load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t to the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
+/// memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
+/// memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+ SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ // TODO: Aborting when the offsets are static. There might be a way to fold
+ // the subview op with load even if the offsets have been canonicalized
+ // away.
+ if (subViewOp.getNumOffsets() == 0)
+ return failure();
+
+ ValueRange opOffsets = subViewOp.offsets();
+ SmallVector<Value, 2> opStrides;
+ if (subViewOp.getNumStrides()) {
+ // If the strides are dynamic, get the stride operands.
+ opStrides = llvm::to_vector<2>(subViewOp.strides());
+ } else {
+ // When static, the stride operands can be retrieved by taking the strides
+ // of the result of the subview op, and dividing the strides of the base
+ // memref.
+ SmallVector<int64_t, 2> staticStrides;
+ if (failed(subViewOp.getStaticStrides(staticStrides))) {
+ return failure();
+ }
+ opStrides.reserve(opOffsets.size());
+ for (auto stride : staticStrides) {
+ auto constValAttr = rewriter.getIntegerAttr(
+ IndexType::get(rewriter.getContext()), stride);
+ opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr));
+ }
+ }
+ assert(opOffsets.size() == opStrides.size());
+
+ // New indices for the load are the current indices * subview_stride +
+ // subview_offset.
+ assert(indices.size() == opStrides.size());
+ sourceIndices.resize(indices.size());
+ for (auto index : llvm::enumerate(indices)) {
+ auto offset = opOffsets[index.index()];
+ auto stride = opStrides[index.index()];
+ auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
+ sourceIndices[index.index()] =
+ rewriter.create<AddIOp>(loc, offset, mul).getResult();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and LoadOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
+ loadOp.indices(), sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
+ sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and StoreOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
+ storeOp.indices(), sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
+ subViewOp.source(), sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Hook for adding patterns.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateStdLegalizationPatternsForSPIRVLowering(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass for testing just the legalization patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void SPIRVLegalization::runOnOperation() {
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+ applyPatternsGreedily(getOperation()->getRegions(), patterns);
+}
+
+std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
+ return std::make_unique<SPIRVLegalization>();
+}
+
+static PassRegistration<SPIRVLegalization>
+ pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering");
OpenPOWER on IntegriCloud