summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToSPIRV
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt26
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp314
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp89
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp181
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td35
5 files changed, 645 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
new file mode 100644
index 00000000000..fcced23a95e
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
+mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
+
+add_llvm_library(MLIRStandardToSPIRVTransforms
+ ConvertStandardToSPIRV.cpp
+ ConvertStandardToSPIRVPass.cpp
+ LegalizeStandardForSPIRV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+ )
+
+add_dependencies(MLIRStandardToSPIRVTransforms
+ MLIRStandardToSPIRVIncGen)
+
+target_link_libraries(MLIRStandardToSPIRVTransforms
+ MLIRIR
+ MLIRPass
+ MLIRSPIRV
+ MLIRSupport
+ MLIRTransformUtils
+ MLIRSPIRV
+ MLIRStandardOps
+ )
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
new file mode 100644
index 00000000000..a02dee4419a
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -0,0 +1,314 @@
+//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineMap.h"
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Convert constant operation with IndexType return to SPIR-V constant
+/// operation. Since IndexType is not used within SPIR-V dialect, this needs
+/// special handling to make sure the result type and the type of the value
+/// attribute are consistent.
+// TODO(ravishankarm) : This should be moved into DRR.
+class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
+public:
+ using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert compare operation to SPIR-V dialect.
+class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
+public:
+ using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use
+/// tablegen for this. If the integer operation is on variables of IndexType,
+/// the type of the return value of the replacement operation differs from
+/// that of the replaced operation. This is not handled in tablegen-based
+/// pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
+template <typename StdOp, typename SPIRVOp>
+class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
+public:
+ using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultType =
+ this->typeConverter.convertType(operation.getResult()->getType());
+ rewriter.template replaceOpWithNewOp<SPIRVOp>(
+ operation, resultType, operands, ArrayRef<NamedAttribute>());
+ return this->matchSuccess();
+ }
+};
+
+/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
+/// IndexType while that of the replacement operation are of type i32. This is
+/// not supported in tablegen based pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
+class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
+public:
+ using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert return -> spv.Return.
+// TODO(ravishankarm) : This should be moved into DRR.
+class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
+public:
+ using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert select -> spv.Select
+// TODO(ravishankarm) : This should be moved into DRR.
+class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
+public:
+ using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
+ PatternMatchResult
+ matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are
+/// of IndexType while that of the replacement operation are of type i32. This
+/// is not supported in tablegen based pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
+class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
+public:
+ using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for operation conversion
+//===----------------------------------------------------------------------===//
+
+/// Performs the index computation to get to the element pointed to by
+/// `indices` using the layout map of `baseType`.
+
+// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
+// MemRefType with AffineMap that has static strides. Handle dynamic strides
+spirv::AccessChainOp getElementPtr(OpBuilder &builder,
+ SPIRVTypeConverter &typeConverter,
+ Location loc, MemRefType origBaseType,
+ Value basePtr, ArrayRef<Value> indices) {
+ // Get base and offset of the MemRefType and verify they are static.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
+ llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+ return nullptr;
+ }
+
+ auto indexType = typeConverter.getIndexType(builder.getContext());
+
+ Value ptrLoc = nullptr;
+ assert(indices.size() == strides.size());
+ for (auto index : enumerate(indices)) {
+ Value strideVal = builder.create<spirv::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+ Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+ ptrLoc =
+ (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
+ : update);
+ }
+ SmallVector<Value, 2> linearizedIndices;
+ // Add a '0' at the start to index into the struct.
+ linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, 0)));
+ linearizedIndices.push_back(ptrLoc);
+ return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp with index type.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
+ ConstantOp constIndexOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
+ return matchFailure();
+ }
+ // The attribute has index type which is not directly supported in
+ // SPIR-V. Get the integer value and create a new IntegerAttr.
+ auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
+ if (!constAttr) {
+ return matchFailure();
+ }
+
+ // Use the bitwidth set in the value attribute to decide the result type
+ // of the SPIR-V constant operation since SPIR-V does not support index
+ // types.
+ auto constVal = constAttr.getValue();
+ auto constValType = constAttr.getType().dyn_cast<IndexType>();
+ if (!constValType) {
+ return matchFailure();
+ }
+ auto spirvConstType =
+ typeConverter.convertType(constIndexOp.getResult()->getType());
+ auto spirvConstVal =
+ rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
+ spirvConstVal);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ CmpIOpOperandAdaptor cmpIOpOperands(operands);
+
+ switch (cmpIOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp) \
+ case cmpPredicate: \
+ rewriter.replaceOpWithNewOp<spirvOp>( \
+ cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \
+ cmpIOpOperands.rhs()); \
+ return matchSuccess();
+
+ DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
+ DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
+ DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
+ DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
+ DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
+ DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+
+#undef DISPATCH
+
+ default:
+ break;
+ }
+ return matchFailure();
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ LoadOpOperandAdaptor loadOperands(operands);
+ auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
+ loadOp.memref()->getType().cast<MemRefType>(),
+ loadOperands.memref(), loadOperands.indices());
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
+ /*memory_access =*/nullptr,
+ /*alignment =*/nullptr);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (returnOp.getNumOperands()) {
+ return matchFailure();
+ }
+ rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ SelectOpOperandAdaptor selectOperands(operands);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
+ selectOperands.true_value(),
+ selectOperands.false_value());
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ StoreOpOperandAdaptor storeOperands(operands);
+ auto storePtr =
+ getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
+ storeOp.memref()->getType().cast<MemRefType>(),
+ storeOperands.memref(), storeOperands.indices());
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
+ storeOperands.value(),
+ /*memory_access =*/nullptr,
+ /*alignment =*/nullptr);
+ return matchSuccess();
+}
+
+namespace {
+/// Import the Standard Ops to SPIR-V Patterns.
+#include "StandardToSPIRV.cpp.inc"
+} // namespace
+
+namespace mlir {
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ // Add patterns that lower operations into SPIR-V dialect.
+ populateWithGenerated(context, &patterns);
+ patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
+ IntegerOpConversion<AddIOp, spirv::IAddOp>,
+ IntegerOpConversion<MulIOp, spirv::IMulOp>,
+ IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
+ IntegerOpConversion<SignedRemIOp, spirv::SModOp>,
+ IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
+ ReturnOpConversion, SelectOpConversion, StoreOpConversion>(
+ context, typeConverter);
+}
+} // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
new file mode 100644
index 00000000000..52456b6e46d
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -0,0 +1,89 @@
+//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard ops into the SPIR-V
+// ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// A simple pattern for rewriting function signature to convert arguments of
+/// functions to be of valid SPIR-V types.
+class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
+public:
+ using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// A pass converting MLIR Standard operations into the SPIR-V dialect.
+class ConvertStandardToSPIRVPass
+ : public ModulePass<ConvertStandardToSPIRVPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+PatternMatchResult
+FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto fnType = funcOp.getType();
+ if (fnType.getNumResults()) {
+ return matchFailure();
+ }
+
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ {
+ for (auto argType : enumerate(funcOp.getType().getInputs())) {
+ auto convertedType = typeConverter.convertType(argType.value());
+ signatureConverter.addInputs(argType.index(), convertedType);
+ }
+ }
+
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+ });
+ return matchSuccess();
+}
+
+void ConvertStandardToSPIRVPass::runOnModule() {
+ OwningRewritePatternList patterns;
+ auto context = &getContext();
+ auto module = getModule();
+
+ SPIRVTypeConverter typeConverter;
+ populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+ patterns.insert<FuncOpConversion>(context, typeConverter);
+ ConversionTarget target(*(module.getContext()));
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
+
+ if (failed(applyPartialConversion(module, target, patterns))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() {
+ return std::make_unique<ConvertStandardToSPIRVPass>();
+}
+
+static PassRegistration<ConvertStandardToSPIRVPass>
+ pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
new file mode 100644
index 00000000000..a658356f76c
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -0,0 +1,181 @@
+//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes operations before the conversion to SPIR-V
+// dialect to handle ops that cannot be lowered directly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Merges subview operation with load operation.
+class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
+public:
+ using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges subview operation with store operation.
+class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
+public:
+ using OpRewritePattern<StoreOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for op legalization.
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of an load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t to the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
+/// memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
+/// memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+ SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ // TODO: Aborting when the offsets are static. There might be a way to fold
+ // the subview op with load even if the offsets have been canonicalized
+ // away.
+ if (subViewOp.getNumOffsets() == 0)
+ return failure();
+
+ ValueRange opOffsets = subViewOp.offsets();
+ SmallVector<Value, 2> opStrides;
+ if (subViewOp.getNumStrides()) {
+ // If the strides are dynamic, get the stride operands.
+ opStrides = llvm::to_vector<2>(subViewOp.strides());
+ } else {
+ // When static, the stride operands can be retrieved by taking the strides
+ // of the result of the subview op, and dividing the strides of the base
+ // memref.
+ SmallVector<int64_t, 2> staticStrides;
+ if (failed(subViewOp.getStaticStrides(staticStrides))) {
+ return failure();
+ }
+ opStrides.reserve(opOffsets.size());
+ for (auto stride : staticStrides) {
+ auto constValAttr = rewriter.getIntegerAttr(
+ IndexType::get(rewriter.getContext()), stride);
+ opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr));
+ }
+ }
+ assert(opOffsets.size() == opStrides.size());
+
+ // New indices for the load are the current indices * subview_stride +
+ // subview_offset.
+ assert(indices.size() == opStrides.size());
+ sourceIndices.resize(indices.size());
+ for (auto index : llvm::enumerate(indices)) {
+ auto offset = opOffsets[index.index()];
+ auto stride = opStrides[index.index()];
+ auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
+ sourceIndices[index.index()] =
+ rewriter.create<AddIOp>(loc, offset, mul).getResult();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and LoadOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
+ loadOp.indices(), sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
+ sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and StoreOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
+ storeOp.indices(), sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
+ subViewOp.source(), sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Hook for adding patterns.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateStdLegalizationPatternsForSPIRVLowering(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass for testing just the legalization patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void SPIRVLegalization::runOnOperation() {
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+ applyPatternsGreedily(getOperation()->getRegions(), patterns);
+}
+
+std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
+ return std::make_unique<SPIRVLegalization>();
+}
+
+static PassRegistration<SPIRVLegalization>
+ pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering");
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
new file mode 100644
index 00000000000..6f3a6a82476
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
@@ -0,0 +1,35 @@
+//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==//
+
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower standard ops to SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_TD
+#define MLIR_CONVERSION_STANDARDTOSPIRV_TD
+
+include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/SPIRV/SPIRVOps.td"
+
+class BinaryOpPattern<Op src, Op tgt> :
+ Pat<(src SPV_ScalarOrVector:$l, SPV_ScalarOrVector:$r),
+ (tgt $l, $r)>;
+
+def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
+def : BinaryOpPattern<DivFOp, SPV_FDivOp>;
+def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
+def : BinaryOpPattern<RemFOp, SPV_FRemOp>;
+def : BinaryOpPattern<SubFOp, SPV_FSubOp>;
+
+// Constant Op
+// TODO(ravishankarm): Handle lowering other constant types.
+def : Pat<(ConstantOp:$result $valueAttr),
+ (SPV_ConstantOp $valueAttr),
+ [(SPV_ScalarOrVector $result)]>;
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
OpenPOWER on IntegriCloud