Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--  mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp | 550
-rw-r--r--  mlir/lib/Conversion/AffineToStandard/CMakeLists.txt | 24
-rw-r--r--  mlir/lib/Conversion/CMakeLists.txt | 12
-rw-r--r--  mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h | 85
-rw-r--r--  mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h | 100
-rw-r--r--  mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt | 16
-rw-r--r--  mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp | 167
-rw-r--r--  mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp | 424
-rw-r--r--  mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt | 18
-rw-r--r--  mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td | 21
-rw-r--r--  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 751
-rw-r--r--  mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt | 10
-rw-r--r--  mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 75
-rw-r--r--  mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt | 15
-rw-r--r--  mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 359
-rw-r--r--  mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp | 96
-rw-r--r--  mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt | 15
-rw-r--r--  mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 549
-rw-r--r--  mlir/lib/Conversion/LoopToStandard/CMakeLists.txt | 22
-rw-r--r--  mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp | 269
-rw-r--r--  mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt | 21
-rw-r--r--  mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp | 528
-rw-r--r--  mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp | 147
-rw-r--r--  mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt | 24
-rw-r--r--  mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 2278
-rw-r--r--  mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt | 26
-rw-r--r--  mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | 314
-rw-r--r--  mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 89
-rw-r--r--  mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 181
-rw-r--r--  mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td | 35
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt | 15
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 766
-rw-r--r--  mlir/lib/Conversion/VectorToLoops/CMakeLists.txt | 15
-rw-r--r--  mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp | 358
34 files changed, 8375 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
new file mode 100644
index 00000000000..e9a9ca82f51
--- /dev/null
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -0,0 +1,550 @@
+//===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file lowers affine constructs (If and For statements, AffineApply
+// operations) within a function into their standard If and For equivalent ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+// Visit affine expressions recursively and build the sequence of operations
+// that corresponds to them. Visitation functions return a Value for the
+// expression subtree they visited or `nullptr` on error.
+class AffineApplyExpander
+ : public AffineExprVisitor<AffineApplyExpander, Value> {
+public:
+  // This internal class expects arguments to be non-null; checks must be
+  // performed at the call site.
+ AffineApplyExpander(OpBuilder &builder, ArrayRef<Value> dimValues,
+ ArrayRef<Value> symbolValues, Location loc)
+ : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
+ loc(loc) {}
+
+ template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) {
+ auto lhs = visit(expr.getLHS());
+ auto rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return nullptr;
+ auto op = builder.create<OpTy>(loc, lhs, rhs);
+ return op.getResult();
+ }
+
+ Value visitAddExpr(AffineBinaryOpExpr expr) {
+ return buildBinaryExpr<AddIOp>(expr);
+ }
+
+ Value visitMulExpr(AffineBinaryOpExpr expr) {
+ return buildBinaryExpr<MulIOp>(expr);
+ }
+
+ // Euclidean modulo operation: negative RHS is not allowed.
+  // Remainder of the Euclidean integer division is always non-negative.
+ //
+ // Implemented as
+ //
+ // a mod b =
+ // let remainder = srem a, b;
+  //             negative = remainder < 0 in
+ // select negative, remainder + b, remainder.
+ Value visitModExpr(AffineBinaryOpExpr expr) {
+ auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!rhsConst) {
+ emitError(
+ loc,
+ "semi-affine expressions (modulo by non-const) are not supported");
+ return nullptr;
+ }
+ if (rhsConst.getValue() <= 0) {
+ emitError(loc, "modulo by non-positive value is not supported");
+ return nullptr;
+ }
+
+ auto lhs = visit(expr.getLHS());
+ auto rhs = visit(expr.getRHS());
+ assert(lhs && rhs && "unexpected affine expr lowering failure");
+
+ Value remainder = builder.create<SignedRemIOp>(loc, lhs, rhs);
+ Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+ Value isRemainderNegative =
+ builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
+ Value correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
+ Value result = builder.create<SelectOp>(loc, isRemainderNegative,
+ correctedRemainder, remainder);
+ return result;
+ }
+
+ // Floor division operation (rounds towards negative infinity).
+ //
+ // For positive divisors, it can be implemented without branching and with a
+ // single division operation as
+ //
+ // a floordiv b =
+ // let negative = a < 0 in
+ // let absolute = negative ? -a - 1 : a in
+ // let quotient = absolute / b in
+ // negative ? -quotient - 1 : quotient
+ Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!rhsConst) {
+ emitError(
+ loc,
+ "semi-affine expressions (division by non-const) are not supported");
+ return nullptr;
+ }
+ if (rhsConst.getValue() <= 0) {
+ emitError(loc, "division by non-positive value is not supported");
+ return nullptr;
+ }
+
+ auto lhs = visit(expr.getLHS());
+ auto rhs = visit(expr.getRHS());
+ assert(lhs && rhs && "unexpected affine expr lowering failure");
+
+ Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+ Value noneCst = builder.create<ConstantIndexOp>(loc, -1);
+ Value negative =
+ builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
+ Value negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
+ Value dividend =
+ builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
+ Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
+ Value correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
+ Value result =
+ builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
+ return result;
+ }
+
+ // Ceiling division operation (rounds towards positive infinity).
+ //
+ // For positive divisors, it can be implemented without branching and with a
+ // single division operation as
+ //
+ // a ceildiv b =
+ // let negative = a <= 0 in
+ // let absolute = negative ? -a : a - 1 in
+ // let quotient = absolute / b in
+ // negative ? -quotient : quotient + 1
+ Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!rhsConst) {
+ emitError(loc) << "semi-affine expressions (division by non-const) are "
+ "not supported";
+ return nullptr;
+ }
+ if (rhsConst.getValue() <= 0) {
+ emitError(loc, "division by non-positive value is not supported");
+ return nullptr;
+ }
+ auto lhs = visit(expr.getLHS());
+ auto rhs = visit(expr.getRHS());
+ assert(lhs && rhs && "unexpected affine expr lowering failure");
+
+ Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+ Value oneCst = builder.create<ConstantIndexOp>(loc, 1);
+ Value nonPositive =
+ builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
+ Value negated = builder.create<SubIOp>(loc, zeroCst, lhs);
+ Value decremented = builder.create<SubIOp>(loc, lhs, oneCst);
+ Value dividend =
+ builder.create<SelectOp>(loc, nonPositive, negated, decremented);
+ Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
+ Value negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
+ Value incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
+ Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
+ incrementedQuotient);
+ return result;
+ }
+
+ Value visitConstantExpr(AffineConstantExpr expr) {
+ auto valueAttr =
+ builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
+ auto op =
+ builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr);
+ return op.getResult();
+ }
+
+ Value visitDimExpr(AffineDimExpr expr) {
+ assert(expr.getPosition() < dimValues.size() &&
+ "affine dim position out of range");
+ return dimValues[expr.getPosition()];
+ }
+
+ Value visitSymbolExpr(AffineSymbolExpr expr) {
+ assert(expr.getPosition() < symbolValues.size() &&
+ "symbol dim position out of range");
+ return symbolValues[expr.getPosition()];
+ }
+
+private:
+ OpBuilder &builder;
+ ArrayRef<Value> dimValues;
+ ArrayRef<Value> symbolValues;
+
+ Location loc;
+};
+} // namespace
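The mod, floordiv and ceildiv rewrites documented above are easy to sanity-check on plain integers. The following stand-alone C++ sketch mirrors the select/division sequences the expander emits (it assumes a strictly positive divisor, which is the precondition the lowering enforces; the helper names are illustrative, not part of the patch):

// Scalar analogues of the branch-free lowerings above; b > 0 assumed.
#include <cassert>

static long euclideanMod(long a, long b) {
  long remainder = a % b;                 // srem
  return remainder < 0 ? remainder + b : remainder;
}

static long floorDiv(long a, long b) {
  bool negative = a < 0;
  long dividend = negative ? -1 - a : a;  // -a - 1
  long quotient = dividend / b;           // sdiv truncates toward zero
  return negative ? -1 - quotient : quotient;
}

static long ceilDiv(long a, long b) {
  bool nonPositive = a <= 0;
  long dividend = nonPositive ? -a : a - 1;
  long quotient = dividend / b;
  return nonPositive ? -quotient : quotient + 1;
}

int main() {
  assert(euclideanMod(-7, 3) == 2);
  assert(floorDiv(-7, 2) == -4 && floorDiv(7, 2) == 3);
  assert(ceilDiv(-7, 2) == -3 && ceilDiv(7, 2) == 4);
  return 0;
}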
+
+// Create a sequence of operations that implement the `expr` applied to the
+// given dimension and symbol values.
+mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
+ AffineExpr expr, ArrayRef<Value> dimValues,
+ ArrayRef<Value> symbolValues) {
+ return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
+}
+
+// Create a sequence of operations that implement the `affineMap` applied to
+// the given `operands` (as if it were an AffineApplyOp).
+Optional<SmallVector<Value, 8>> static expandAffineMap(
+ OpBuilder &builder, Location loc, AffineMap affineMap,
+ ArrayRef<Value> operands) {
+ auto numDims = affineMap.getNumDims();
+ auto expanded = functional::map(
+ [numDims, &builder, loc, operands](AffineExpr expr) {
+ return expandAffineExpr(builder, loc, expr,
+ operands.take_front(numDims),
+ operands.drop_front(numDims));
+ },
+ affineMap.getResults());
+ if (llvm::all_of(expanded, [](Value v) { return v; }))
+ return expanded;
+ return None;
+}
+
+// Given a range of values, emit the code that reduces them with "min" or "max"
+// depending on the provided comparison predicate. The predicate defines which
+// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
+// `cmpi` operation followed by the `select` operation:
+//
+// %cond = cmpi "predicate" %v0, %v1
+// %result = select %cond, %v0, %v1
+//
+// Multiple values are scanned in a linear sequence. This creates a data
+// dependence that wouldn't exist in a tree reduction, but is easier to
+// recognize as a reduction by the subsequent passes.
+static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
+ ArrayRef<Value> values,
+ OpBuilder &builder) {
+ assert(!llvm::empty(values) && "empty min/max chain");
+
+ auto valueIt = values.begin();
+ Value value = *valueIt++;
+ for (; valueIt != values.end(); ++valueIt) {
+ auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
+ value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
+ }
+
+ return value;
+}
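The linear cmpi/select chain built by buildMinMaxReductionSeq is just a left fold. A minimal scalar analogue of the "sgt" (max) case used for lower bounds, with hypothetical helper names (a sketch, not part of the patch):

#include <cassert>
#include <vector>

// Linear max reduction with the same shape as the emitted cmpi/select chain.
static long reduceMax(const std::vector<long> &values) {
  assert(!values.empty() && "empty min/max chain");
  long acc = values.front();
  for (size_t i = 1, e = values.size(); i < e; ++i) {
    bool cond = acc > values[i];  // cmpi "sgt" %acc, %v
    acc = cond ? acc : values[i]; // select %cond, %acc, %v
  }
  return acc;
}

int main() {
  assert(reduceMax({3, 9, 5}) == 9);
  return 0;
}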
+
+// Emit instructions that correspond to the affine map in the lower bound
+// applied to the respective operands, and compute the maximum value across
+// the results.
+Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
+ SmallVector<Value, 8> boundOperands(op.getLowerBoundOperands());
+ auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(),
+ boundOperands);
+ if (!lbValues)
+ return nullptr;
+ return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
+ builder);
+}
+
+// Emit instructions that correspond to the affine map in the upper bound
+// applied to the respective operands, and compute the minimum value across
+// the results.
+Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
+ SmallVector<Value, 8> boundOperands(op.getUpperBoundOperands());
+ auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(),
+ boundOperands);
+ if (!ubValues)
+ return nullptr;
+ return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
+ builder);
+}
+
+namespace {
+// Affine terminators are rewritten into loop.terminator operations.
+class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
+public:
+ using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineTerminatorOp op,
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<loop::TerminatorOp>(op);
+ return matchSuccess();
+ }
+};
+
+class AffineForLowering : public OpRewritePattern<AffineForOp> {
+public:
+ using OpRewritePattern<AffineForOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineForOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value lowerBound = lowerAffineLowerBound(op, rewriter);
+ Value upperBound = lowerAffineUpperBound(op, rewriter);
+ Value step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
+ auto f = rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
+ f.region().getBlocks().clear();
+ rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+
+class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
+public:
+ using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineIfOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ // Now we just have to handle the condition logic.
+ auto integerSet = op.getIntegerSet();
+ Value zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
+ SmallVector<Value, 8> operands(op.getOperands());
+ auto operandsRef = llvm::makeArrayRef(operands);
+
+ // Calculate cond as a conjunction without short-circuiting.
+ Value cond = nullptr;
+ for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
+ AffineExpr constraintExpr = integerSet.getConstraint(i);
+ bool isEquality = integerSet.isEq(i);
+
+ // Build and apply an affine expression
+ auto numDims = integerSet.getNumDims();
+ Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
+ operandsRef.take_front(numDims),
+ operandsRef.drop_front(numDims));
+ if (!affResult)
+ return matchFailure();
+ auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
+ Value cmpVal =
+ rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
+ cond =
+ cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal;
+ }
+ cond = cond ? cond
+ : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1);
+
+ bool hasElseRegion = !op.elseRegion().empty();
+ auto ifOp = rewriter.create<loop::IfOp>(loc, cond, hasElseRegion);
+ rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back());
+ ifOp.thenRegion().back().erase();
+ if (hasElseRegion) {
+ rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back());
+ ifOp.elseRegion().back().erase();
+ }
+
+ // Ok, we're done!
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+
+// Convert an "affine.apply" operation into a sequence of arithmetic
+// operations using the StandardOps dialect.
+class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
+public:
+ using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineApplyOp op,
+ PatternRewriter &rewriter) const override {
+ auto maybeExpandedMap =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
+ llvm::to_vector<8>(op.getOperands()));
+ if (!maybeExpandedMap)
+ return matchFailure();
+ rewriter.replaceOp(op, *maybeExpandedMap);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.load' operation to its operands, and
+// feed the results to a newly created 'std.load' operation (which replaces the
+// original 'affine.load').
+class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
+public:
+ using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineLoadOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map from 'affineLoadOp'.
+ SmallVector<Value, 8> indices(op.getMapOperands());
+ auto resultOperands =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+ if (!resultOperands)
+ return matchFailure();
+
+ // Build std.load memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.prefetch' operation to its operands, and
+// feed the results to a newly created 'std.prefetch' operation (which replaces
+// the original 'affine.prefetch').
+class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
+public:
+ using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map from 'affinePrefetchOp'.
+ SmallVector<Value, 8> indices(op.getMapOperands());
+ auto resultOperands =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+ if (!resultOperands)
+ return matchFailure();
+
+ // Build std.prefetch memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<PrefetchOp>(
+ op, op.memref(), *resultOperands, op.isWrite(),
+ op.localityHint().getZExtValue(), op.isDataCache());
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.store' operation to its operands, and
+// feed the results to a newly created 'std.store' operation (which replaces the
+// original 'affine.store').
+class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
+public:
+ using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineStoreOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map from 'affineStoreOp'.
+ SmallVector<Value, 8> indices(op.getMapOperands());
+ auto maybeExpandedMap =
+ expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+ if (!maybeExpandedMap)
+ return matchFailure();
+
+ // Build std.store valueToStore, memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
+ op.getMemRef(), *maybeExpandedMap);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine maps from an 'affine.dma_start' operation to each of their
+// respective map operands, and feed the results to a newly created
+// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
+class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
+public:
+ using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineDmaStartOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value, 8> operands(op.getOperands());
+ auto operandsRef = llvm::makeArrayRef(operands);
+
+ // Expand affine map for DMA source memref.
+ auto maybeExpandedSrcMap = expandAffineMap(
+ rewriter, op.getLoc(), op.getSrcMap(),
+ operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
+ if (!maybeExpandedSrcMap)
+ return matchFailure();
+ // Expand affine map for DMA destination memref.
+ auto maybeExpandedDstMap = expandAffineMap(
+ rewriter, op.getLoc(), op.getDstMap(),
+ operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
+ if (!maybeExpandedDstMap)
+ return matchFailure();
+ // Expand affine map for DMA tag memref.
+ auto maybeExpandedTagMap = expandAffineMap(
+ rewriter, op.getLoc(), op.getTagMap(),
+ operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
+ if (!maybeExpandedTagMap)
+ return matchFailure();
+
+ // Build std.dma_start operation with affine map results.
+ rewriter.replaceOpWithNewOp<DmaStartOp>(
+ op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
+ *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
+ *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.dma_wait' operation's tag memref,
+// and feed the results to a newly created 'std.dma_wait' operation (which
+// replaces the original 'affine.dma_wait').
+class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
+public:
+ using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineDmaWaitOp op,
+ PatternRewriter &rewriter) const override {
+ // Expand affine map for DMA tag memref.
+ SmallVector<Value, 8> indices(op.getTagIndices());
+ auto maybeExpandedTagMap =
+ expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
+ if (!maybeExpandedTagMap)
+ return matchFailure();
+
+ // Build std.dma_wait operation with affine map results.
+ rewriter.replaceOpWithNewOp<DmaWaitOp>(
+ op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
+ return matchSuccess();
+ }
+};
+
+} // end namespace
+
+void mlir::populateAffineToStdConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ patterns.insert<
+ AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering,
+ AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering,
+ AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx);
+}
+
+namespace {
+class LowerAffinePass : public FunctionPass<LowerAffinePass> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ populateAffineToStdConversionPatterns(patterns, &getContext());
+ ConversionTarget target(getContext());
+ target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
+ if (failed(applyPartialConversion(getFunction(), target, patterns)))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+/// Lowers affine If, For, and Apply operations within a function into their
+/// loop dialect and standard arithmetic equivalents.
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createLowerAffinePass() {
+ return std::make_unique<LowerAffinePass>();
+}
+
+static PassRegistration<LowerAffinePass>
+ pass("lower-affine",
+ "Lower If, For, AffineApply operations to primitive equivalents");
diff --git a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt
new file mode 100644
index 00000000000..33f7db7abc4
--- /dev/null
+++ b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_llvm_library(MLIRAffineToStandard
+ AffineToStandard.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AffineToStandard
+)
+add_dependencies(
+ MLIRAffineToStandard
+
+ MLIRAffineOps
+ MLIRStandardOps
+ MLIRIR
+ LLVMCore
+ LLVMSupport
+)
+target_link_libraries(
+ MLIRAffineToStandard
+
+ MLIRAffineOps
+ MLIRStandardOps
+ MLIRIR
+ LLVMCore
+ LLVMSupport
+)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
new file mode 100644
index 00000000000..c791d214d30
--- /dev/null
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_subdirectory(AffineToStandard)
+add_subdirectory(GPUToCUDA)
+add_subdirectory(GPUToNVVM)
+add_subdirectory(GPUToROCDL)
+add_subdirectory(GPUToSPIRV)
+add_subdirectory(LinalgToLLVM)
+add_subdirectory(LoopsToGPU)
+add_subdirectory(LoopToStandard)
+add_subdirectory(StandardToLLVM)
+add_subdirectory(StandardToSPIRV)
+add_subdirectory(VectorToLLVM)
+add_subdirectory(VectorToLoops)
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
new file mode 100644
index 00000000000..63bc15173be
--- /dev/null
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -0,0 +1,85 @@
+//===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
+#define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+#include "llvm/ADT/StringSwitch.h"
+
+namespace mlir {
+
+// Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
+// that Op operates on. Op is assumed to return an `std.index` value and
+// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on
+// `indexBitwidth`, sign-extend or truncate the resulting value to match the
+// bitwidth expected by the consumers of the value.
+template <typename Op, typename XOp, typename YOp, typename ZOp>
+struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering {
+private:
+ enum dimension { X = 0, Y = 1, Z = 2, invalid };
+ unsigned indexBitwidth;
+
+ static dimension dimensionToIndex(Op op) {
+ return llvm::StringSwitch<dimension>(op.dimension())
+ .Case("x", X)
+ .Case("y", Y)
+ .Case("z", Z)
+ .Default(invalid);
+ }
+
+ static unsigned getIndexBitWidth(LLVMTypeConverter &type_converter) {
+ auto dialect = type_converter.getDialect();
+ return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+ }
+
+public:
+ explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(Op::getOperationName(),
+ lowering_.getDialect()->getContext(), lowering_),
+ indexBitwidth(getIndexBitWidth(lowering_)) {}
+
+  // Replace the op with the intrinsic matching its dimension attribute and
+  // adjust the result to the expected index bitwidth.
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto dialect = lowering.getDialect();
+ Value newOp;
+ switch (dimensionToIndex(cast<Op>(op))) {
+ case X:
+ newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ break;
+ case Y:
+ newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ break;
+ case Z:
+ newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ break;
+ default:
+ return matchFailure();
+ }
+
+ if (indexBitwidth > 32) {
+ newOp = rewriter.create<LLVM::SExtOp>(
+ loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+ } else if (indexBitwidth < 32) {
+ newOp = rewriter.create<LLVM::TruncOp>(
+ loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+ }
+
+ rewriter.replaceOp(op, {newOp});
+ return matchSuccess();
+ }
+};
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
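As a usage sketch, the template is meant to be instantiated once per GPU index operation when populating lowering patterns. The GPU/NVVM op class names below are taken from the dialects included above and the include path mirrors how LowerGpuOpsToNVVMOps.cpp (later in this patch) pulls in this header; the snippet itself is illustrative only:

// Sketch: register index-op lowerings for gpu.thread_id and gpu.block_id.
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

void addGpuIndexLoweringPatterns(mlir::LLVMTypeConverter &converter,
                                 mlir::OwningRewritePatternList &patterns) {
  using namespace mlir;
  patterns.insert<
      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter);
}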
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
new file mode 100644
index 00000000000..b75c1bf2d7b
--- /dev/null
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -0,0 +1,100 @@
+//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
+#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+
+namespace mlir {
+
+/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`
+/// depending on the element type that the op operates upon. The function
+/// declaration is added if it does not exist yet.
+///
+/// Example with NVVM:
+/// %exp_f32 = std.exp %arg_f32 : f32
+///
+/// will be transformed into
+/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
+template <typename SourceOp>
+struct OpToFuncCallLowering : public LLVMOpLowering {
+public:
+ explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
+ StringRef f64Func)
+ : LLVMOpLowering(SourceOp::getOperationName(),
+ lowering_.getDialect()->getContext(), lowering_),
+ f32Func(f32Func), f64Func(f64Func) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ using LLVM::LLVMFuncOp;
+ using LLVM::LLVMType;
+
+ static_assert(
+ std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+ "expected single result op");
+
+ LLVMType resultType = lowering.convertType(op->getResult(0)->getType())
+ .template cast<LLVM::LLVMType>();
+ LLVMType funcType = getFunctionType(resultType, operands);
+ StringRef funcName = getFunctionName(resultType);
+ if (funcName.empty())
+ return matchFailure();
+
+ LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
+ auto callOp = rewriter.create<LLVM::CallOp>(
+ op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
+ rewriter.replaceOp(op, {callOp.getResult(0)});
+ return matchSuccess();
+ }
+
+private:
+ LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
+ ArrayRef<Value> operands) const {
+ using LLVM::LLVMType;
+ SmallVector<LLVMType, 1> operandTypes;
+ for (Value operand : operands) {
+ operandTypes.push_back(operand->getType().cast<LLVMType>());
+ }
+ return LLVMType::getFunctionTy(resultType, operandTypes,
+ /*isVarArg=*/false);
+ }
+
+ StringRef getFunctionName(LLVM::LLVMType type) const {
+ if (type.isFloatTy())
+ return f32Func;
+ if (type.isDoubleTy())
+ return f64Func;
+ return "";
+ }
+
+ LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName,
+ LLVM::LLVMType funcType,
+ Operation *op) const {
+ using LLVM::LLVMFuncOp;
+
+ Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
+ if (funcOp)
+ return cast<LLVMFuncOp>(*funcOp);
+
+ mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
+ return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
+ }
+
+ const std::string f32Func;
+ const std::string f64Func;
+};
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
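Following the std.exp example in the class comment, the pattern would typically be registered along the lines of the sketch below (illustrative only; ExpOp is the standard exp operation and the __nv_expf/__nv_exp names are the ones quoted above):

// Sketch: map std.exp to the device exp functions named in the example above.
#include "../GPUCommon/OpToFuncCallLowering.h"

void addExpLoweringPattern(mlir::LLVMTypeConverter &converter,
                           mlir::OwningRewritePatternList &patterns) {
  patterns.insert<mlir::OpToFuncCallLowering<mlir::ExpOp>>(
      converter, "__nv_expf", "__nv_exp");
}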
diff --git a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
new file mode 100644
index 00000000000..4eddb787493
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
@@ -0,0 +1,16 @@
+if(MLIR_CUDA_CONVERSIONS_ENABLED)
+ llvm_map_components_to_libnames(nvptx "NVPTX")
+
+ add_llvm_library(MLIRGPUtoCUDATransforms
+ ConvertKernelFuncToCubin.cpp
+ ConvertLaunchFuncToCudaCalls.cpp
+ )
+ target_link_libraries(MLIRGPUtoCUDATransforms
+ MLIRGPU
+ MLIRLLVMIR
+ MLIRNVVMIR
+ MLIRPass
+ MLIRTargetNVVMIR
+ ${nvptx}
+ )
+endif()
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
new file mode 100644
index 00000000000..66a2e66f99a
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -0,0 +1,167 @@
+//===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert gpu kernel functions into a
+// corresponding binary blob that can be executed on a CUDA GPU. Currently it
+// only translates the function itself and none of its dependencies.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/NVVMIR.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace mlir;
+
+namespace {
+// TODO(herhut): Move to shared location.
+static constexpr const char *kCubinAnnotation = "nvvm.cubin";
+
+/// A pass converting tagged kernel modules to cubin blobs.
+///
+/// If tagged as a kernel module, each contained function is translated to NVVM
+/// IR and further to PTX. A user provided CubinGenerator compiles the PTX to
+/// GPU binary code, which is then attached as an attribute to the function. The
+/// function body is erased.
+class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> {
+public:
+ GpuKernelToCubinPass(
+ CubinGenerator cubinGenerator = compilePtxToCubinForTesting)
+ : cubinGenerator(cubinGenerator) {}
+
+ void runOnModule() override {
+ ModuleOp module = getModule();
+ if (!module.getAttrOfType<UnitAttr>(
+ gpu::GPUDialect::getKernelModuleAttrName()) ||
+ !module.getName())
+ return;
+
+ // Make sure the NVPTX target is initialized.
+ LLVMInitializeNVPTXTarget();
+ LLVMInitializeNVPTXTargetInfo();
+ LLVMInitializeNVPTXTargetMC();
+ LLVMInitializeNVPTXAsmPrinter();
+
+ auto llvmModule = translateModuleToNVVMIR(module);
+ if (!llvmModule)
+ return signalPassFailure();
+
+ // Translate the module to CUBIN and attach the result as attribute to the
+ // module.
+ if (auto cubinAttr = translateGpuModuleToCubinAnnotation(
+ *llvmModule, module.getLoc(), *module.getName()))
+ module.setAttr(kCubinAnnotation, cubinAttr);
+ else
+ signalPassFailure();
+ }
+
+private:
+ static OwnedCubin compilePtxToCubinForTesting(const std::string &ptx,
+ Location, StringRef);
+
+ std::string translateModuleToPtx(llvm::Module &module,
+ llvm::TargetMachine &target_machine);
+
+ /// Converts llvmModule to cubin using the user-provided generator. Location
+ /// is used for error reporting and name is forwarded to the CUBIN generator
+ /// to use in its logging mechanisms.
+ OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Location loc,
+ StringRef name);
+
+ /// Translates llvmModule to cubin and returns the result as attribute.
+ StringAttr translateGpuModuleToCubinAnnotation(llvm::Module &llvmModule,
+ Location loc, StringRef name);
+
+ CubinGenerator cubinGenerator;
+};
+
+} // anonymous namespace
+
+std::string GpuKernelToCubinPass::translateModuleToPtx(
+ llvm::Module &module, llvm::TargetMachine &target_machine) {
+ std::string ptx;
+ {
+ llvm::raw_string_ostream stream(ptx);
+ llvm::buffer_ostream pstream(stream);
+ llvm::legacy::PassManager codegen_passes;
+ target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr,
+ llvm::CGFT_AssemblyFile);
+ codegen_passes.run(module);
+ }
+
+ return ptx;
+}
+
+OwnedCubin
+GpuKernelToCubinPass::compilePtxToCubinForTesting(const std::string &ptx,
+ Location, StringRef) {
+ const char data[] = "CUBIN";
+ return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
+}
+
+OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
+ Location loc,
+ StringRef name) {
+ std::unique_ptr<llvm::TargetMachine> targetMachine;
+ {
+ std::string error;
+ // TODO(herhut): Make triple configurable.
+ constexpr const char *cudaTriple = "nvptx64-nvidia-cuda";
+ llvm::Triple triple(cudaTriple);
+ const llvm::Target *target =
+ llvm::TargetRegistry::lookupTarget("", triple, error);
+ if (target == nullptr) {
+ emitError(loc, "cannot initialize target triple");
+ return {};
+ }
+ targetMachine.reset(
+ target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {}));
+ }
+
+ // Set the data layout of the llvm module to match what the ptx target needs.
+ llvmModule.setDataLayout(targetMachine->createDataLayout());
+
+ auto ptx = translateModuleToPtx(llvmModule, *targetMachine);
+
+ return cubinGenerator(ptx, loc, name);
+}
+
+StringAttr GpuKernelToCubinPass::translateGpuModuleToCubinAnnotation(
+ llvm::Module &llvmModule, Location loc, StringRef name) {
+ auto cubin = convertModuleToCubin(llvmModule, loc, name);
+ if (!cubin)
+ return {};
+ return StringAttr::get({cubin->data(), cubin->size()}, loc->getContext());
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) {
+ return std::make_unique<GpuKernelToCubinPass>(cubinGenerator);
+}
+
+static PassRegistration<GpuKernelToCubinPass>
+ pass("test-kernel-to-cubin",
+ "Convert all kernel functions to CUDA cubin blobs");
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
new file mode 100644
index 00000000000..19dabcdafee
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -0,0 +1,424 @@
+//===- ConvertLaunchFuncToCudaCalls.cpp - MLIR CUDA lowering passes -------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert gpu.launch_func op into a sequence of
+// CUDA runtime calls. As the CUDA runtime does not have a stable published ABI,
+// this pass uses a slim runtime layer that builds on top of the public API from
+// the CUDA headers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+// To avoid name mangling, these are defined in the mini-runtime file.
+static constexpr const char *cuModuleLoadName = "mcuModuleLoad";
+static constexpr const char *cuModuleGetFunctionName = "mcuModuleGetFunction";
+static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel";
+static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper";
+static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize";
+static constexpr const char *kMcuMemHostRegister = "mcuMemHostRegister";
+
+static constexpr const char *kCubinAnnotation = "nvvm.cubin";
+static constexpr const char *kCubinStorageSuffix = "_cubin_cst";
+
+namespace {
+
+/// A pass to convert gpu.launch_func operations into a sequence of CUDA
+/// runtime calls.
+///
+/// In essence, a gpu.launch_func operation gets compiled into the following
+/// sequence of runtime calls:
+///
+/// * mcuModuleLoad -- loads the module given the cubin data
+/// * mcuModuleGetFunction -- gets a handle to the actual kernel function
+/// * mcuGetStreamHelper -- initializes a new CUDA stream
+/// * mcuLaunchKernel -- launches the kernel on a stream
+/// * mcuStreamSynchronize -- waits for operations on the stream to finish
+///
+/// Intermediate data structures are allocated on the stack.
+class GpuLaunchFuncToCudaCallsPass
+ : public ModulePass<GpuLaunchFuncToCudaCallsPass> {
+private:
+ LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
+
+ llvm::LLVMContext &getLLVMContext() {
+ return getLLVMDialect()->getLLVMContext();
+ }
+
+ void initializeCachedTypes() {
+ const llvm::Module &module = llvmDialect->getLLVMModule();
+ llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
+ llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ llvmPointerPointerType = llvmPointerType.getPointerTo();
+ llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
+ llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
+ llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ llvmIntPtrType = LLVM::LLVMType::getIntNTy(
+ llvmDialect, module.getDataLayout().getPointerSizeInBits());
+ }
+
+ LLVM::LLVMType getVoidType() { return llvmVoidType; }
+
+ LLVM::LLVMType getPointerType() { return llvmPointerType; }
+
+ LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; }
+
+ LLVM::LLVMType getInt8Type() { return llvmInt8Type; }
+
+ LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
+
+ LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
+
+ LLVM::LLVMType getIntPtrType() {
+ const llvm::Module &module = getLLVMDialect()->getLLVMModule();
+ return LLVM::LLVMType::getIntNTy(
+ getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
+ }
+
+ LLVM::LLVMType getCUResultType() {
+ // This is declared as an enum in CUDA but helpers use i32.
+ return getInt32Type();
+ }
+
+ // Allocate a void pointer on the stack.
+ Value allocatePointer(OpBuilder &builder, Location loc) {
+ auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+ builder.getI32IntegerAttr(1));
+ return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one,
+ /*alignment=*/0);
+ }
+
+ void declareCudaFunctions(Location loc);
+ Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
+ Value generateKernelNameConstant(StringRef name, Location loc,
+ OpBuilder &builder);
+ void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
+
+public:
+ // Run the dialect converter on the module.
+ void runOnModule() override {
+ // Cache the LLVMDialect for the current module.
+ llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
+ // Cache the used LLVM types.
+ initializeCachedTypes();
+
+ getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
+ translateGpuLaunchCalls(op);
+ });
+
+ // GPU kernel modules are no longer necessary since we have a global
+ // constant with the CUBIN data.
+ for (auto m : llvm::make_early_inc_range(getModule().getOps<ModuleOp>()))
+ if (m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName()))
+ m.erase();
+ }
+
+private:
+ LLVM::LLVMDialect *llvmDialect;
+ LLVM::LLVMType llvmVoidType;
+ LLVM::LLVMType llvmPointerType;
+ LLVM::LLVMType llvmPointerPointerType;
+ LLVM::LLVMType llvmInt8Type;
+ LLVM::LLVMType llvmInt32Type;
+ LLVM::LLVMType llvmInt64Type;
+ LLVM::LLVMType llvmIntPtrType;
+};
+
+} // anonymous namespace
+
+// Adds declarations for the needed helper functions from the CUDA wrapper.
+// The types in comments give the actual types expected/returned but the API
+// uses void pointers. This is fine as they have the same linkage in C.
+void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
+ ModuleOp module = getModule();
+ OpBuilder builder(module.getBody()->getTerminator());
+ if (!module.lookupSymbol(cuModuleLoadName)) {
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, cuModuleLoadName,
+ LLVM::LLVMType::getFunctionTy(
+ getCUResultType(),
+ {
+ getPointerPointerType(), /* CUmodule *module */
+ getPointerType() /* void *cubin */
+ },
+ /*isVarArg=*/false));
+ }
+ if (!module.lookupSymbol(cuModuleGetFunctionName)) {
+ // The helper uses void* instead of CUDA's opaque CUmodule and
+ // CUfunction.
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, cuModuleGetFunctionName,
+ LLVM::LLVMType::getFunctionTy(
+ getCUResultType(),
+ {
+ getPointerPointerType(), /* void **function */
+ getPointerType(), /* void *module */
+ getPointerType() /* char *name */
+ },
+ /*isVarArg=*/false));
+ }
+ if (!module.lookupSymbol(cuLaunchKernelName)) {
+    // Unlike the CUDA API, the wrappers use uintptr_t to match the LLVM type
+    // of MLIR's index type, which the GPU dialect uses.
+ // Furthermore, they use void* instead of CUDA's opaque CUfunction and
+ // CUstream.
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, cuLaunchKernelName,
+ LLVM::LLVMType::getFunctionTy(
+ getCUResultType(),
+ {
+ getPointerType(), /* void* f */
+ getIntPtrType(), /* intptr_t gridXDim */
+ getIntPtrType(), /* intptr_t gridyDim */
+ getIntPtrType(), /* intptr_t gridZDim */
+ getIntPtrType(), /* intptr_t blockXDim */
+ getIntPtrType(), /* intptr_t blockYDim */
+ getIntPtrType(), /* intptr_t blockZDim */
+ getInt32Type(), /* unsigned int sharedMemBytes */
+ getPointerType(), /* void *hstream */
+ getPointerPointerType(), /* void **kernelParams */
+ getPointerPointerType() /* void **extra */
+ },
+ /*isVarArg=*/false));
+ }
+ if (!module.lookupSymbol(cuGetStreamHelperName)) {
+ // Helper function to get the current CUDA stream. Uses void* instead of
+    // CUDA's opaque CUstream.
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, cuGetStreamHelperName,
+ LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
+ }
+ if (!module.lookupSymbol(cuStreamSynchronizeName)) {
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, cuStreamSynchronizeName,
+ LLVM::LLVMType::getFunctionTy(getCUResultType(),
+ getPointerType() /* CUstream stream */,
+ /*isVarArg=*/false));
+ }
+ if (!module.lookupSymbol(kMcuMemHostRegister)) {
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, kMcuMemHostRegister,
+ LLVM::LLVMType::getFunctionTy(getVoidType(),
+ {
+ getPointerType(), /* void *ptr */
+ getInt64Type() /* int64 sizeBytes*/
+ },
+ /*isVarArg=*/false));
+ }
+}
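For reference, the wrapper ABI these declarations bind to corresponds to the following C-style prototypes, reconstructed from the LLVM function types above (the wrappers themselves live in a separate mini-runtime and are not part of this patch; the McuResult alias is introduced here only for readability):

#include <cstdint>

// The helpers pass CUresult around as a plain 32-bit integer.
typedef int32_t McuResult;

extern "C" {
McuResult mcuModuleLoad(void **module, void *cubin);
McuResult mcuModuleGetFunction(void **function, void *module, char *name);
McuResult mcuLaunchKernel(void *f, intptr_t gridXDim, intptr_t gridYDim,
                          intptr_t gridZDim, intptr_t blockXDim,
                          intptr_t blockYDim, intptr_t blockZDim,
                          unsigned sharedMemBytes, void *hstream,
                          void **kernelParams, void **extra);
void *mcuGetStreamHelper();
McuResult mcuStreamSynchronize(void *stream);
void mcuMemHostRegister(void *ptr, int64_t sizeBytes);
}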
+
+// Generates a parameters array to be used with a CUDA kernel launch call. The
+// arguments are extracted from the launchOp.
+// The generated code is essentially as follows:
+//
+// %array = alloca(numparams * sizeof(void *))
+// for (i : [0, NumKernelOperands))
+// %array[i] = cast<void*>(KernelOperand[i])
+// return %array
+Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
+ OpBuilder &builder) {
+ auto numKernelOperands = launchOp.getNumKernelOperands();
+ Location loc = launchOp.getLoc();
+ auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+ builder.getI32IntegerAttr(1));
+  // Provision one pointer slot in `array` per kernel operand; struct (memref)
+  // operands get an extra level of indirection through a separate alloca
+  // below.
+ auto arraySize = builder.create<LLVM::ConstantOp>(
+ loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands));
+ auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
+ arraySize, /*alignment=*/0);
+ for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
+ auto operand = launchOp.getKernelOperand(idx);
+ auto llvmType = operand->getType().cast<LLVM::LLVMType>();
+ Value memLocation = builder.create<LLVM::AllocaOp>(
+ loc, llvmType.getPointerTo(), one, /*alignment=*/1);
+ builder.create<LLVM::StoreOp>(loc, operand, memLocation);
+ auto casted =
+ builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+
+    // Assume all struct arguments come from MemRef. If this assumption does
+    // not hold anymore then `launchOp` needs to be lowered before the LLVM
+    // conversion has taken place, i.e. while the MemRef type information is
+    // still available.
+ // Extra level of indirection in the `array`:
+    // the descriptor pointer is registered via @mcuMemHostRegister.
+ if (llvmType.isStructTy()) {
+ auto registerFunc =
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
+ auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
+ auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
+ ArrayRef<Value>{nullPtr, one});
+ auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
+ builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
+ builder.getSymbolRefAttr(registerFunc),
+ ArrayRef<Value>{casted, size});
+ Value memLocation = builder.create<LLVM::AllocaOp>(
+ loc, getPointerPointerType(), one, /*alignment=*/1);
+ builder.create<LLVM::StoreOp>(loc, casted, memLocation);
+ casted =
+ builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+ }
+
+ auto index = builder.create<LLVM::ConstantOp>(
+ loc, getInt32Type(), builder.getI32IntegerAttr(idx));
+ auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
+ ArrayRef<Value>{index});
+ builder.create<LLVM::StoreOp>(loc, casted, gep);
+ }
+ return array;
+}
+
+// Generates an LLVM IR dialect global that contains the name of the given
+// kernel function as a C string, and returns a pointer to its beginning.
+// The code is essentially:
+//
+// llvm.global constant @kernel_name("function_name\00")
+// func(...) {
+// %0 = llvm.addressof @kernel_name
+// %1 = llvm.constant (0 : index)
+// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
+// }
+Value GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
+ StringRef name, Location loc, OpBuilder &builder) {
+ // Make sure the trailing zero is included in the constant.
+ std::vector<char> kernelName(name.begin(), name.end());
+ kernelName.push_back('\0');
+
+ std::string globalName = llvm::formatv("{0}_kernel_name", name);
+ return LLVM::createGlobalString(
+ loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
+ LLVM::Linkage::Internal, llvmDialect);
+}
+
+// Emits LLVM IR to launch a kernel function. Expects the module that contains
+// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute of the
+// kernel module in the IR. The cubin data is materialized as an LLVM global
+// constant from which a pointer to the blob is obtained.
+// With this given, the generated code in essence is
+//
+// %0 = <pointer to the cubin blob global, see kCubinStorageSuffix>
+// %1 = alloca sizeof(void*)
+// call %mcuModuleLoad(%1, %0)
+// %2 = alloca sizeof(void*)
+// %3 = load %1
+// %4 = <see generateKernelNameConstant>
+// call %mcuModuleGetFunction(%2, %3, %4)
+// %5 = call %mcuGetStreamHelper()
+// %6 = load %2
+// %7 = <see setupParamsArray>
+// call %mcuLaunchKernel(%6, <launchOp operands 0..5>, 0, %5, %7, nullptr)
+// call %mcuStreamSynchronize(%5)
+void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
+ mlir::gpu::LaunchFuncOp launchOp) {
+ OpBuilder builder(launchOp);
+ Location loc = launchOp.getLoc();
+ declareCudaFunctions(loc);
+
+ auto zero = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
+ builder.getI32IntegerAttr(0));
+ // Create an LLVM global with CUBIN extracted from the kernel annotation and
+ // obtain a pointer to the first byte in it.
+ auto kernelModule =
+ getModule().lookupSymbol<ModuleOp>(launchOp.getKernelModuleName());
+ assert(kernelModule && "expected a kernel module");
+
+ auto cubinAttr = kernelModule.getAttrOfType<StringAttr>(kCubinAnnotation);
+ if (!cubinAttr) {
+ kernelModule.emitOpError()
+ << "missing " << kCubinAnnotation << " attribute";
+ return signalPassFailure();
+ }
+
+ assert(kernelModule.getName() && "expected a named module");
+ SmallString<128> nameBuffer(*kernelModule.getName());
+ nameBuffer.append(kCubinStorageSuffix);
+ Value data = LLVM::createGlobalString(
+ loc, builder, nameBuffer.str(), cubinAttr.getValue(),
+ LLVM::Linkage::Internal, getLLVMDialect());
+
+ // Emit the load module call to load the module data. Error checking is done
+ // in the called helper function.
+ auto cuModule = allocatePointer(builder, loc);
+ auto cuModuleLoad =
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
+ builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
+ builder.getSymbolRefAttr(cuModuleLoad),
+ ArrayRef<Value>{cuModule, data});
+ // Get the function from the module. The name corresponds to the name of
+ // the kernel function.
+ auto cuOwningModuleRef =
+ builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
+ auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
+ auto cuFunction = allocatePointer(builder, loc);
+ auto cuModuleGetFunction =
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
+ builder.create<LLVM::CallOp>(
+ loc, ArrayRef<Type>{getCUResultType()},
+ builder.getSymbolRefAttr(cuModuleGetFunction),
+ ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
+ // Grab the global stream needed for execution.
+ auto cuGetStreamHelper =
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
+ auto cuStream = builder.create<LLVM::CallOp>(
+ loc, ArrayRef<Type>{getPointerType()},
+ builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
+ // Invoke the function with required arguments.
+ auto cuLaunchKernel =
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
+ auto cuFunctionRef =
+ builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
+ auto paramsArray = setupParamsArray(launchOp, builder);
+ auto nullpointer =
+ builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
+ builder.create<LLVM::CallOp>(
+ loc, ArrayRef<Type>{getCUResultType()},
+ builder.getSymbolRefAttr(cuLaunchKernel),
+ ArrayRef<Value>{cuFunctionRef, launchOp.getOperand(0),
+ launchOp.getOperand(1), launchOp.getOperand(2),
+ launchOp.getOperand(3), launchOp.getOperand(4),
+ launchOp.getOperand(5), zero, /* sharedMemBytes */
+ cuStream.getResult(0), /* stream */
+ paramsArray, /* kernel params */
+ nullpointer /* extra */});
+ // Sync on the stream to make it synchronous.
+ auto cuStreamSync =
+ getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
+ builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
+ builder.getSymbolRefAttr(cuStreamSync),
+ ArrayRef<Value>(cuStream.getResult(0)));
+ launchOp.erase();
+}
+
+std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+mlir::createConvertGpuLaunchFuncToCudaCallsPass() {
+ return std::make_unique<GpuLaunchFuncToCudaCallsPass>();
+}
+
+static PassRegistration<GpuLaunchFuncToCudaCallsPass>
+ pass("launch-func-to-cuda",
+ "Convert all launch_func ops to CUDA runtime calls");
diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
new file mode 100644
index 00000000000..b5df446abe1
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -0,0 +1,18 @@
+set(LLVM_TARGET_DEFINITIONS GPUToNVVM.td)
+mlir_tablegen(GPUToNVVM.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRGPUToNVVMIncGen)
+
+add_llvm_library(MLIRGPUtoNVVMTransforms
+ LowerGpuOpsToNVVMOps.cpp
+ )
+
+add_dependencies(MLIRGPUtoNVVMTransforms
+ MLIRGPUToNVVMIncGen)
+
+target_link_libraries(MLIRGPUtoNVVMTransforms
+ LLVMSupport
+ MLIRGPU
+ MLIRLLVMIR
+ MLIRNVVMIR
+ MLIRPass
+ )
diff --git a/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td b/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td
new file mode 100644
index 00000000000..0a6aec07041
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td
@@ -0,0 +1,21 @@
+//==-- GPUToNVVM.td - GPU Ops to NVVM Patterns ---------------*- tablegen -*==//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower GPU ops to NVVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTONVVM_TD
+#define MLIR_CONVERSION_GPUTONVVM_TD
+
+include "mlir/Dialect/GPU/GPUOps.td"
+include "mlir/Dialect/LLVMIR/NVVMOps.td"
+
+def : Pat<(GPU_BarrierOp), (NVVM_Barrier0Op)>;
+
+#endif // MLIR_CONVERSION_GPUTONVVM_TD
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
new file mode 100644
index 00000000000..08c18c1ec83
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -0,0 +1,751 @@
+//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate NVVMIR operations for higher-level
+// GPU operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/Support/FormatVariadic.h"
+
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Derived type converter for GPU to NVVM lowering. The GPU dialect uses memory
+/// space 5 for private memory attributions, but NVVM represents private
+/// memory allocations as local `alloca`s in the default address space. This
+/// converter drops the private memory space to support the use case above.
+class NVVMTypeConverter : public LLVMTypeConverter {
+public:
+ using LLVMTypeConverter::LLVMTypeConverter;
+
+ Type convertType(Type type) override {
+ auto memref = type.dyn_cast<MemRefType>();
+ if (memref &&
+ memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) {
+ type = MemRefType::get(memref.getShape(), memref.getElementType(),
+ memref.getAffineMaps());
+ }
+
+ return LLVMTypeConverter::convertType(type);
+ }
+};
+
+/// Converts all_reduce op to LLVM/NVVM ops.
+struct GPUAllReduceOpLowering : public LLVMOpLowering {
+ using AccumulatorFactory =
+ std::function<Value(Location, Value, Value, ConversionPatternRewriter &)>;
+
+ explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(gpu::AllReduceOp::getOperationName(),
+ lowering_.getDialect()->getContext(), lowering_),
+ int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ Value operand = operands.front();
+
+ // TODO(csigg): Generalize to other types of accumulation.
+ assert(op->getOperand(0)->getType().isIntOrFloat());
+
+ // Create the reduction using an accumulator factory.
+ AccumulatorFactory factory =
+ getFactory(cast<gpu::AllReduceOp>(op), operand);
+ assert(factory && "failed to create accumulator factory");
+ Value result = createBlockReduce(loc, operand, factory, rewriter);
+
+ rewriter.replaceOp(op, {result});
+ return matchSuccess();
+ }
+
+private:
+ /// Returns an accumulator factory using either the op attribute or the body
+ /// region.
+ AccumulatorFactory getFactory(gpu::AllReduceOp allReduce,
+ Value operand) const {
+ if (!allReduce.body().empty()) {
+ return getFactory(allReduce.body());
+ }
+ if (allReduce.op()) {
+ auto type = operand->getType().cast<LLVM::LLVMType>();
+ return getFactory(*allReduce.op(), type.getUnderlyingType());
+ }
+ return AccumulatorFactory();
+ }
+
+  /// Returns an accumulator factory that clones the body. The body's entry
+  /// block is expected to have 2 arguments. The gpu.yield op returns the
+  /// accumulated value of the same type.
+ AccumulatorFactory getFactory(Region &body) const {
+ return AccumulatorFactory([&](Location loc, Value lhs, Value rhs,
+ ConversionPatternRewriter &rewriter) {
+ Block *block = rewriter.getInsertionBlock();
+ Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+
+      // Insert the accumulator body between the split blocks.
+ BlockAndValueMapping mapping;
+ mapping.map(body.front().getArgument(0), lhs);
+ mapping.map(body.front().getArgument(1), rhs);
+ rewriter.cloneRegionBefore(body, *split->getParent(),
+ split->getIterator(), mapping);
+
+      // Add a branch from the block before the inserted body into the body.
+ block = block->getNextNode();
+ rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
+ llvm::makeArrayRef(block), ValueRange());
+
+ // Replace all gpu.yield ops with branch out of body.
+ for (; block != split; block = block->getNextNode()) {
+ Operation *terminator = block->getTerminator();
+ if (!llvm::isa<gpu::YieldOp>(terminator))
+ continue;
+ rewriter.setInsertionPointToEnd(block);
+ rewriter.replaceOpWithNewOp<LLVM::BrOp>(
+ terminator, ArrayRef<Value>{}, llvm::makeArrayRef(split),
+ ValueRange(terminator->getOperand(0)));
+ }
+
+ // Return accumulator result.
+ rewriter.setInsertionPointToStart(split);
+ return split->addArgument(lhs->getType());
+ });
+ }
+
+ /// Returns an accumulator factory that creates an op specified by opName.
+ AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const {
+ if (type->isVectorTy() || type->isArrayTy())
+ return getFactory(opName, type->getSequentialElementType());
+
+ bool isFloatingPoint = type->isFloatingPointTy();
+
+ if (opName == "add") {
+ return isFloatingPoint ? getFactory<LLVM::FAddOp>()
+ : getFactory<LLVM::AddOp>();
+ }
+ if (opName == "mul") {
+ return isFloatingPoint ? getFactory<LLVM::FMulOp>()
+ : getFactory<LLVM::MulOp>();
+ }
+
+ return AccumulatorFactory();
+ }
+
+ /// Returns an accumulator factory that creates an op of type T.
+ template <typename T> AccumulatorFactory getFactory() const {
+ return [](Location loc, Value lhs, Value rhs,
+ ConversionPatternRewriter &rewriter) {
+ return rewriter.create<T>(loc, lhs->getType(), lhs, rhs);
+ };
+ }
+
+ /// Creates an all_reduce across the block.
+ ///
+ /// First reduce the elements within a warp. The first thread of each warp
+ /// writes the intermediate result to shared memory. After synchronizing the
+ /// block, the first warp reduces the values from shared memory. The result
+  /// is broadcast to all threads through shared memory.
+ ///
+ /// %warp_reduce = `createWarpReduce(%operand)`
+ /// %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
+ /// %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
+ /// %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
+ /// %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
+ /// %thread_idx = `getLinearThreadIndex()` : !llvm.i32
+ /// llvm.cond_br %is_first_lane, ^then1, ^continue1
+ /// ^then1:
+ /// %warp_id = `getWarpId()`
+ /// %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
+ /// llvm.store %store_dst, %warp_reduce
+ /// llvm.br ^continue1
+ /// ^continue1:
+ /// nvvm.barrier0
+ /// %num_warps = `getNumWarps()` : !llvm.i32
+ /// %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
+ /// %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
+ /// llvm.cond_br %is_first_lane, ^then2, ^continue2
+ /// ^then2:
+ /// %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
+ /// %value = llvm.load %load_src
+ /// %result = `createWarpReduce(%value)`
+ /// llvm.store %result_ptr, %result
+ /// llvm.br ^continue2
+ /// ^continue2:
+ /// nvvm.barrier0
+ /// %result = llvm.load %result_ptr
+ /// return %result
+ ///
+ Value createBlockReduce(Location loc, Value operand,
+ AccumulatorFactory &accumFactory,
+ ConversionPatternRewriter &rewriter) const {
+ auto type = operand->getType().cast<LLVM::LLVMType>();
+
+ // Create shared memory array to store the warp reduction.
+ auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>();
+ assert(module && "op must belong to a module");
+ Value sharedMemPtr =
+ createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
+
+ Value zero = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(0u));
+ Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
+ Value isFirstLane = rewriter.create<LLVM::ICmpOp>(
+ loc, LLVM::ICmpPredicate::eq, laneId, zero);
+ Value threadIdx = getLinearThreadIndex(loc, rewriter);
+ Value blockSize = getBlockSize(loc, rewriter);
+ Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
+
+ // Reduce elements within each warp to produce the intermediate results.
+ Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
+ accumFactory, rewriter);
+
+ // Write the intermediate results to shared memory, using the first lane of
+ // each warp.
+ createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
+ Value warpId = getDivideByWarpSize(threadIdx, rewriter);
+ Value storeDst = rewriter.create<LLVM::GEPOp>(
+ loc, type, sharedMemPtr, ArrayRef<Value>({zero, warpId}));
+ rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
+ });
+ rewriter.create<NVVM::Barrier0Op>(loc);
+
+ Value numWarps = getNumWarps(loc, blockSize, rewriter);
+ Value isValidWarp = rewriter.create<LLVM::ICmpOp>(
+ loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
+ Value resultPtr = rewriter.create<LLVM::GEPOp>(
+ loc, type, sharedMemPtr, ArrayRef<Value>({zero, zero}));
+
+ // Use the first numWarps threads to reduce the intermediate results from
+ // shared memory. The final result is written to shared memory again.
+ createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
+ Value loadSrc = rewriter.create<LLVM::GEPOp>(
+ loc, type, sharedMemPtr, ArrayRef<Value>({zero, threadIdx}));
+ Value value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
+ Value result = createWarpReduce(loc, numWarps, laneId, value,
+ accumFactory, rewriter);
+ rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
+ });
+ rewriter.create<NVVM::Barrier0Op>(loc);
+
+ // Load and return result from shared memory.
+ Value result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
+ return result;
+ }
+
+  /// Creates an if-block skeleton and calls the two factories to generate the
+  /// ops in the `then` and `else` blocks.
+  ///
+  ///     llvm.cond_br %condition, ^then, ^else
+ /// ^then:
+ /// %then_operands = `thenOpsFactory()`
+ /// llvm.br ^continue(%then_operands)
+ /// ^else:
+ /// %else_operands = `elseOpsFactory()`
+ /// llvm.br ^continue(%else_operands)
+ /// ^continue(%block_operands):
+ ///
+ template <typename ThenOpsFactory, typename ElseOpsFactory>
+ void createIf(Location loc, ConversionPatternRewriter &rewriter,
+ Value condition, ThenOpsFactory &&thenOpsFactory,
+ ElseOpsFactory &&elseOpsFactory) const {
+ Block *currentBlock = rewriter.getInsertionBlock();
+ auto currentPoint = rewriter.getInsertionPoint();
+
+ Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
+ Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
+ Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
+
+ rewriter.setInsertionPointToEnd(currentBlock);
+ rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition),
+ ArrayRef<Block *>{thenBlock, elseBlock});
+
+ auto addBranch = [&](ValueRange operands) {
+ rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
+ llvm::makeArrayRef(continueBlock),
+ llvm::makeArrayRef(operands));
+ };
+
+ rewriter.setInsertionPointToStart(thenBlock);
+ auto thenOperands = thenOpsFactory();
+ addBranch(thenOperands);
+
+ rewriter.setInsertionPointToStart(elseBlock);
+ auto elseOperands = elseOpsFactory();
+ addBranch(elseOperands);
+
+ assert(thenOperands.size() == elseOperands.size());
+ rewriter.setInsertionPointToStart(continueBlock);
+ for (auto operand : thenOperands)
+ continueBlock->addArgument(operand->getType());
+ }
+
+ /// Shortcut for createIf with empty else block and no block operands.
+ template <typename Factory>
+ void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
+ Value condition,
+ Factory &&predicatedOpsFactory) const {
+ createIf(
+ loc, rewriter, condition,
+ [&] {
+ predicatedOpsFactory();
+ return ArrayRef<Value>();
+ },
+ [&] { return ArrayRef<Value>(); });
+ }
+
+ /// Creates a reduction across the first activeWidth lanes of a warp.
+  /// The first lane returns the result; the return values of all other lanes
+  /// are undefined.
+ Value createWarpReduce(Location loc, Value activeWidth, Value laneId,
+ Value operand, AccumulatorFactory accumFactory,
+ ConversionPatternRewriter &rewriter) const {
+ Value warpSize = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
+ Value isPartialWarp = rewriter.create<LLVM::ICmpOp>(
+ loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
+ auto type = operand->getType().cast<LLVM::LLVMType>();
+
+ createIf(
+ loc, rewriter, isPartialWarp,
+ // Generate reduction over a (potentially) partial warp.
+ [&] {
+ Value value = operand;
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(1));
+ // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+ Value activeMask = rewriter.create<LLVM::SubOp>(
+ loc, int32Type,
+ rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
+ one);
+ // Clamp lane: `activeWidth - 1`
+ Value maskAndClamp =
+ rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
+ auto dialect = lowering.getDialect();
+ auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
+ auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
+ auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+
+ // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
+ // lane is within the active range. All lanes contain the final
+ // result, but only the first lane's result is used.
+ for (int i = 1; i < kWarpSize; i <<= 1) {
+ Value offset = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(i));
+ Value shfl = rewriter.create<NVVM::ShflBflyOp>(
+ loc, shflTy, activeMask, value, offset, maskAndClamp,
+ returnValueAndIsValidAttr);
+ Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
+ loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
+ // Skip the accumulation if the shuffle op read from a lane outside
+ // of the active range.
+ createIf(
+ loc, rewriter, isActiveSrcLane,
+ [&] {
+ Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, type, shfl, rewriter.getIndexArrayAttr(0));
+ return SmallVector<Value, 1>{
+ accumFactory(loc, value, shflValue, rewriter)};
+ },
+ [&] { return llvm::makeArrayRef(value); });
+ value = rewriter.getInsertionBlock()->getArgument(0);
+ }
+ return SmallVector<Value, 1>{value};
+ },
+ // Generate a reduction over the entire warp. This is a specialization
+ // of the above reduction with unconditional accumulation.
+ [&] {
+ Value value = operand;
+ Value activeMask = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(~0u));
+ Value maskAndClamp = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+ for (int i = 1; i < kWarpSize; i <<= 1) {
+ Value offset = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(i));
+ Value shflValue = rewriter.create<NVVM::ShflBflyOp>(
+ loc, type, activeMask, value, offset, maskAndClamp,
+ /*return_value_and_is_valid=*/UnitAttr());
+ value = accumFactory(loc, value, shflValue, rewriter);
+ }
+ return SmallVector<Value, 1>{value};
+ });
+ return rewriter.getInsertionBlock()->getArgument(0);
+ }
+
+ /// Creates a global array stored in shared memory.
+ Value createSharedMemoryArray(Location loc, ModuleOp module,
+ LLVM::LLVMType elementType, int numElements,
+ ConversionPatternRewriter &rewriter) const {
+ OpBuilder builder(module.getBodyRegion());
+
+ auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
+ StringRef name = "reduce_buffer";
+ auto globalOp = builder.create<LLVM::GlobalOp>(
+ loc, arrayType.cast<LLVM::LLVMType>(),
+ /*isConstant=*/false, LLVM::Linkage::Internal, name,
+ /*value=*/Attribute(), gpu::GPUDialect::getWorkgroupAddressSpace());
+
+ return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
+ }
+
+ /// Returns the index of the thread within the block.
+ Value getLinearThreadIndex(Location loc,
+ ConversionPatternRewriter &rewriter) const {
+ Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+ Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+ Value idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
+ Value idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
+ Value idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
+ Value tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
+ Value tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
+ Value tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
+ return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
+ }
+
+ /// Returns the number of threads in the block.
+ Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
+ Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+ Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+ Value dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
+ Value dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
+ return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
+ }
+
+ /// Returns the number of warps in the block.
+ Value getNumWarps(Location loc, Value blockSize,
+ ConversionPatternRewriter &rewriter) const {
+ auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
+ auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
+ loc, int32Type, blockSize, warpSizeMinusOne);
+ return getDivideByWarpSize(biasedBlockSize, rewriter);
+ }
+
+ /// Returns the number of active threads in the warp, not clamped to 32.
+ Value getActiveWidth(Location loc, Value threadIdx, Value blockSize,
+ ConversionPatternRewriter &rewriter) const {
+ Value threadIdxMask = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
+ Value numThreadsWithSmallerWarpId =
+ rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
+ return rewriter.create<LLVM::SubOp>(loc, blockSize,
+ numThreadsWithSmallerWarpId);
+ }
+
+ /// Returns value divided by the warp size (i.e. 32).
+ Value getDivideByWarpSize(Value value,
+ ConversionPatternRewriter &rewriter) const {
+ auto loc = value->getLoc();
+ auto warpSize = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
+ return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
+ }
+
+ LLVM::LLVMType int32Type;
+
+ static constexpr int kWarpSize = 32;
+};
+
+struct GPUShuffleOpLowering : public LLVMOpLowering {
+ explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(gpu::ShuffleOp::getOperationName(),
+ lowering_.getDialect()->getContext(), lowering_) {}
+
+ /// Lowers a shuffle to the corresponding NVVM op.
+ ///
+ /// Convert the `width` argument into an activeMask (a bitmask which specifies
+ /// which threads participate in the shuffle) and a maskAndClamp (specifying
+ /// the highest lane which participates in the shuffle).
+ ///
+ /// %one = llvm.constant(1 : i32) : !llvm.i32
+ /// %shl = llvm.shl %one, %width : !llvm.i32
+ /// %active_mask = llvm.sub %shl, %one : !llvm.i32
+ /// %mask_and_clamp = llvm.sub %width, %one : !llvm.i32
+ /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
+ /// %mask_and_clamp : !llvm<"{ float, i1 }">
+ /// %shfl_value = llvm.extractvalue %shfl[0 : index] :
+ /// !llvm<"{ float, i1 }">
+ /// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
+ /// !llvm<"{ float, i1 }">
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ gpu::ShuffleOpOperandAdaptor adaptor(operands);
+
+ auto dialect = lowering.getDialect();
+ auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>();
+ auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
+ auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
+ auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(1));
+ // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+ Value activeMask = rewriter.create<LLVM::SubOp>(
+ loc, int32Type,
+ rewriter.create<LLVM::ShlOp>(loc, int32Type, one, adaptor.width()),
+ one);
+ // Clamp lane: `activeWidth - 1`
+ Value maskAndClamp =
+ rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);
+
+ auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+ Value shfl = rewriter.create<NVVM::ShflBflyOp>(
+ loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
+ maskAndClamp, returnValueAndIsValidAttr);
+ Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
+ Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
+ loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
+
+ rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+ return matchSuccess();
+ }
+};
+
+struct GPUFuncOpLowering : LLVMOpLowering {
+ explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
+ typeConverter.getDialect()->getContext(),
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(operands.empty() && "func op is not expected to have operands");
+ auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
+ Location loc = gpuFuncOp.getLoc();
+
+ SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
+ workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
+ for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
+ Value attribution = en.value();
+
+ auto type = attribution->getType().dyn_cast<MemRefType>();
+ assert(type && type.hasStaticShape() && "unexpected type in attribution");
+
+ uint64_t numElements = type.getNumElements();
+
+ auto elementType =
+ lowering.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+ auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
+ std::string name =
+ llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index());
+ auto globalOp = rewriter.create<LLVM::GlobalOp>(
+ gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+ LLVM::Linkage::Internal, name, /*value=*/Attribute(),
+ gpu::GPUDialect::getWorkgroupAddressSpace());
+ workgroupBuffers.push_back(globalOp);
+ }
+
+ // Rewrite the original GPU function to an LLVM function.
+ auto funcType = lowering.convertType(gpuFuncOp.getType())
+ .cast<LLVM::LLVMType>()
+ .getPointerElementTy();
+
+ // Remap proper input types.
+ TypeConverter::SignatureConversion signatureConversion(
+ gpuFuncOp.front().getNumArguments());
+ for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
+ signatureConversion.addInputs(i, funcType.getFunctionParamType(i));
+
+ // Create the new function operation. Only copy those attributes that are
+ // not specific to function modeling.
+ SmallVector<NamedAttribute, 4> attributes;
+ for (const auto &attr : gpuFuncOp.getAttrs()) {
+ if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
+ attr.first.is(impl::getTypeAttrName()) ||
+ attr.first.is(gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()))
+ continue;
+ attributes.push_back(attr);
+ }
+ auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+ gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
+ LLVM::Linkage::External, attributes);
+
+ {
+ // Insert operations that correspond to converted workgroup and private
+ // memory attributions to the body of the function. This must operate on
+ // the original function, before the body region is inlined in the new
+ // function to maintain the relation between block arguments and the
+ // parent operation that assigns their semantics.
+ OpBuilder::InsertionGuard guard(rewriter);
+
+ // Rewrite workgroup memory attributions to addresses of global buffers.
+ rewriter.setInsertionPointToStart(&gpuFuncOp.front());
+ unsigned numProperArguments = gpuFuncOp.getNumArguments();
+ auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
+
+ Value zero = nullptr;
+ if (!workgroupBuffers.empty())
+ zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
+ rewriter.getI32IntegerAttr(0));
+ for (auto en : llvm::enumerate(workgroupBuffers)) {
+ LLVM::GlobalOp global = en.value();
+ Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
+ auto elementType = global.getType().getArrayElementType();
+ Value memory = rewriter.create<LLVM::GEPOp>(
+ loc, elementType.getPointerTo(global.addr_space().getZExtValue()),
+ address, ArrayRef<Value>{zero, zero});
+
+        // Build a memref descriptor pointing to the buffer so it plugs into
+        // the existing memref infrastructure. This may use more registers
+        // than strictly necessary given that memref sizes are fixed, but we
+        // can try to canonicalize that away later.
+ Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
+ auto type = attribution->getType().cast<MemRefType>();
+ auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
+ type, memory);
+ signatureConversion.remapInput(numProperArguments + en.index(), descr);
+ }
+
+ // Rewrite private memory attributions to alloca'ed buffers.
+ unsigned numWorkgroupAttributions =
+ gpuFuncOp.getNumWorkgroupAttributions();
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
+ Value attribution = en.value();
+ auto type = attribution->getType().cast<MemRefType>();
+ assert(type && type.hasStaticShape() &&
+ "unexpected type in attribution");
+
+ // Explicitly drop memory space when lowering private memory
+ // attributions since NVVM models it as `alloca`s in the default
+ // memory space and does not support `alloca`s with addrspace(5).
+ auto ptrType = lowering.convertType(type.getElementType())
+ .cast<LLVM::LLVMType>()
+ .getPointerTo();
+ Value numElements = rewriter.create<LLVM::ConstantOp>(
+ gpuFuncOp.getLoc(), int64Ty,
+ rewriter.getI64IntegerAttr(type.getNumElements()));
+ Value allocated = rewriter.create<LLVM::AllocaOp>(
+ gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
+ auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
+ type, allocated);
+ signatureConversion.remapInput(
+ numProperArguments + numWorkgroupAttributions + en.index(), descr);
+ }
+ }
+
+ // Move the region to the new function, update the entry block signature.
+ rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
+ llvmFuncOp.end());
+ rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
+ signatureConversion);
+
+ {
+ // For memref-typed arguments, insert the relevant loads in the beginning
+ // of the block to comply with the LLVM dialect calling convention. This
+ // needs to be done after signature conversion to get the right types.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block &block = llvmFuncOp.front();
+ rewriter.setInsertionPointToStart(&block);
+
+ for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) {
+ if (!en.value().isa<MemRefType>() &&
+ !en.value().isa<UnrankedMemRefType>())
+ continue;
+
+ BlockArgument arg = block.getArgument(en.index());
+ Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+ rewriter.replaceUsesOfBlockArgument(arg, loaded);
+ }
+ }
+
+ rewriter.eraseOp(gpuFuncOp);
+ return matchSuccess();
+ }
+};
+
+struct GPUReturnOpLowering : public LLVMOpLowering {
+ GPUReturnOpLowering(LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(gpu::ReturnOp::getOperationName(),
+ typeConverter.getDialect()->getContext(),
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands,
+ ArrayRef<Block *>());
+ return matchSuccess();
+ }
+};
+
+/// Import the GPU Ops to NVVM Patterns.
+#include "GPUToNVVM.cpp.inc"
+
+/// A pass that replaces all occurrences of GPU device operations with their
+/// corresponding NVVM equivalent.
+///
+/// This pass only handles device code and is not meant to be run on GPU host
+/// code.
+class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
+public:
+ void runOnModule() override {
+ ModuleOp m = getModule();
+ if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName()))
+ return;
+
+ OwningRewritePatternList patterns;
+ NVVMTypeConverter converter(m.getContext());
+ populateStdToLLVMConversionPatterns(converter, patterns);
+ populateGpuToNVVMConversionPatterns(converter, patterns);
+ ConversionTarget target(getContext());
+ target.addIllegalDialect<gpu::GPUDialect>();
+ target.addIllegalOp<LLVM::ExpOp>();
+ target.addIllegalOp<FuncOp>();
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addLegalDialect<NVVM::NVVMDialect>();
+ // TODO(csigg): Remove once we support replacing non-root ops.
+ target.addLegalOp<gpu::YieldOp>();
+ if (failed(applyPartialConversion(m, target, patterns, &converter)))
+ signalPassFailure();
+ }
+};
+
+} // anonymous namespace
+
+void mlir::populateGpuToNVVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+ patterns
+ .insert<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+ NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+ NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
+ NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
+ NVVM::GridDimYOp, NVVM::GridDimZOp>,
+ GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
+ GPUReturnOpLowering>(converter);
+ patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
+ "__nv_exp");
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() {
+ return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
+}
+
+static PassRegistration<LowerGpuOpsToNVVMOpsPass>
+ pass("convert-gpu-to-nvvm", "Generate NVVM operations for gpu operations");
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
new file mode 100644
index 00000000000..3c97e5ca86b
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRGPUtoROCDLTransforms
+ LowerGpuOpsToROCDLOps.cpp
+ )
+target_link_libraries(MLIRGPUtoROCDLTransforms
+ LLVMSupport
+ MLIRGPU
+ MLIRLLVMIR
+ MLIRROCDLIR
+ MLIRPass
+ )
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
new file mode 100644
index 00000000000..83770641bd4
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -0,0 +1,75 @@
+//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate ROCDLIR operations for higher-level
+// GPU operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+using namespace mlir;
+
+namespace {
+
+// A pass that replaces all occurrences of GPU device operations with their
+// corresponding ROCDL equivalent.
+//
+// This pass only handles device code and is not meant to be run on GPU host
+// code.
+class LowerGpuOpsToROCDLOpsPass : public ModulePass<LowerGpuOpsToROCDLOpsPass> {
+public:
+ void runOnModule() override {
+ ModuleOp m = getModule();
+ if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName()))
+ return;
+
+ OwningRewritePatternList patterns;
+ LLVMTypeConverter converter(m.getContext());
+ populateStdToLLVMConversionPatterns(converter, patterns);
+ patterns.insert<
+ GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
+ ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
+ ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, ROCDL::BlockIdXOp,
+ ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
+ ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
+ converter);
+ patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "_ocml_exp_f32",
+ "_ocml_exp_f64");
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
+ target.addIllegalOp<LLVM::ExpOp>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+ if (failed(applyPartialConversion(m, target, patterns, &converter)))
+ signalPassFailure();
+ }
+};
+
+} // anonymous namespace
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToROCDLOpsPass() {
+ return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
+}
+
+static PassRegistration<LowerGpuOpsToROCDLOpsPass>
+ pass("convert-gpu-to-rocdl",
+ "Generate ROCDL operations for gpu operations");
diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
new file mode 100644
index 00000000000..be82894461d
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRGPUtoSPIRVTransforms
+ ConvertGPUToSPIRV.cpp
+ ConvertGPUToSPIRVPass.cpp
+ )
+
+target_link_libraries(MLIRGPUtoSPIRVTransforms
+ MLIRGPU
+ MLIRIR
+ MLIRPass
+ MLIRSPIRV
+ MLIRStandardOps
+ MLIRStandardToSPIRVTransforms
+ MLIRSupport
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
new file mode 100644
index 00000000000..509457d076a
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -0,0 +1,359 @@
+//===- ConvertGPUToSPIRV.cpp - Convert GPU ops to SPIR-V dialect ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the conversion patterns from GPU ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to convert a loop::ForOp within kernel functions into spirv::LoopOp.
+class ForOpConversion final : public SPIRVOpLowering<loop::ForOp> {
+public:
+ using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern that lowers GPU block/thread size/id ops to loads of the
+/// corresponding SPIR-V invocation builtin variables.
+template <typename SourceOp, spirv::BuiltIn builtin>
+class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
+public:
+ using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a kernel function in GPU dialect within a spv.module.
+class KernelFnConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
+public:
+ KernelFnConversion(MLIRContext *context, SPIRVTypeConverter &converter,
+ ArrayRef<int64_t> workGroupSize,
+ PatternBenefit benefit = 1)
+ : SPIRVOpLowering<gpu::GPUFuncOp>(context, converter, benefit) {
+ auto config = workGroupSize.take_front(3);
+ workGroupSizeAsInt32.assign(config.begin(), config.end());
+ workGroupSizeAsInt32.resize(3, 1);
+ }
+
+ PatternMatchResult
+ matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+
+private:
+ SmallVector<int32_t, 3> workGroupSizeAsInt32;
+};
+
+/// Pattern to convert a module with gpu.kernel_module attribute to a
+/// spv.module.
+class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
+public:
+ using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a module terminator op to a terminator of spv.module op.
+// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
+// in ODS.
+class KernelModuleTerminatorConversion final
+ : public SPIRVOpLowering<ModuleTerminatorOp> {
+public:
+ using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a gpu.return into a SPIR-V return.
+// TODO: This can go to DRR when GPU return has operands.
+class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
+public:
+ using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// loop::ForOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // loop::ForOp can be lowered to the structured control flow represented by
+ // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
+ // latch and the merge block the exit block. The resulting spirv::LoopOp has a
+  // single back edge from the continue block to the header block, and a
+  // single exit from the header block to the merge block.
+ loop::ForOpOperandAdaptor forOperands(operands);
+ auto loc = forOp.getLoc();
+ auto loopControl = rewriter.getI32IntegerAttr(
+ static_cast<uint32_t>(spirv::LoopControl::None));
+ auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
+ loopOp.addEntryAndMergeBlock();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Create the block for the header.
+ auto header = new Block();
+ // Insert the header.
+ loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+
+ // Create the new induction variable to use.
+ BlockArgument newIndVar =
+ header->addArgument(forOperands.lowerBound()->getType());
+ Block *body = forOp.getBody();
+
+  // Apply signature conversion to the body of the forOp. It has a single block
+  // whose single argument is the induction variable. That argument has to be
+  // replaced with the new induction variable.
+ TypeConverter::SignatureConversion signatureConverter(
+ body->getNumArguments());
+ signatureConverter.remapInput(0, newIndVar);
+ body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
+ signatureConverter);
+
+ // Delete the loop terminator.
+ rewriter.eraseOp(body->getTerminator());
+
+ // Move the blocks from the forOp into the loopOp. This is the body of the
+ // loopOp.
+ rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
+ std::next(loopOp.body().begin(), 2));
+
+ // Branch into it from the entry.
+ rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
+ rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+
+ // Generate the rest of the loop header.
+ rewriter.setInsertionPointToEnd(header);
+ auto mergeBlock = loopOp.getMergeBlock();
+ auto cmpOp = rewriter.create<spirv::SLessThanOp>(
+ loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+ rewriter.create<spirv::BranchConditionalOp>(
+ loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+
+ // Generate instructions to increment the step of the induction variable and
+ // branch to the header.
+ Block *continueBlock = loopOp.getContinueBlock();
+ rewriter.setInsertionPointToEnd(continueBlock);
+
+ // Add the step to the induction variable and branch to the header.
+ Value updatedIndVar = rewriter.create<spirv::IAddOp>(
+ loc, newIndVar->getType(), newIndVar, forOperands.step());
+ rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+
+ rewriter.eraseOp(forOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Builtins.
+//===----------------------------------------------------------------------===//
+
+template <typename SourceOp, spirv::BuiltIn builtin>
+PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
+ SourceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto dimAttr =
+ op.getOperation()->template getAttrOfType<StringAttr>("dimension");
+ if (!dimAttr) {
+ return this->matchFailure();
+ }
+ int32_t index = 0;
+ if (dimAttr.getValue() == "x") {
+ index = 0;
+ } else if (dimAttr.getValue() == "y") {
+ index = 1;
+ } else if (dimAttr.getValue() == "z") {
+ index = 2;
+ } else {
+ return this->matchFailure();
+ }
+
+ // SPIR-V invocation builtin variables are a vector of type <3xi32>
+ auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ op, rewriter.getIntegerType(32), spirvBuiltin,
+ rewriter.getI32ArrayAttr({index}));
+ return this->matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// GPUFuncOp
+//===----------------------------------------------------------------------===//
+
+// Legalizes a GPU function as an entry SPIR-V function.
+static FuncOp
+lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter,
+ spirv::EntryPointABIAttr entryPointInfo,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
+ auto fnType = funcOp.getType();
+ if (fnType.getNumResults()) {
+ funcOp.emitError("SPIR-V lowering only supports entry functions"
+ "with no return values right now");
+ return nullptr;
+ }
+ if (fnType.getNumInputs() != argABIInfo.size()) {
+ funcOp.emitError(
+ "lowering as entry functions requires ABI info for all arguments");
+ return nullptr;
+ }
+ // Update the signature to valid SPIR-V types and add the ABI
+ // attributes. These will be "materialized" by using the
+ // LowerABIAttributesPass.
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ {
+ for (auto argType : enumerate(funcOp.getType().getInputs())) {
+ auto convertedType = typeConverter.convertType(argType.value());
+ signatureConverter.addInputs(argType.index(), convertedType);
+ }
+ }
+ auto newFuncOp = rewriter.create<FuncOp>(
+ funcOp.getLoc(), funcOp.getName(),
+ rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
+ llvm::None),
+ ArrayRef<NamedAttribute>());
+ for (const auto &namedAttr : funcOp.getAttrs()) {
+ if (namedAttr.first.is(impl::getTypeAttrName()) ||
+ namedAttr.first.is(SymbolTable::getSymbolAttrName()))
+ continue;
+ newFuncOp.setAttr(namedAttr.first, namedAttr.second);
+ }
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+ rewriter.eraseOp(funcOp);
+
+ spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo);
+ return newFuncOp;
+}
+
+PatternMatchResult
+KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!gpu::GPUDialect::isKernel(funcOp)) {
+ return matchFailure();
+ }
+
+ SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
+ for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+ argABI.push_back(spirv::getInterfaceVarABIAttr(
+ 0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
+ }
+
+ auto context = rewriter.getContext();
+ auto entryPointAttr =
+ spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context);
+ FuncOp newFuncOp = lowerAsEntryFunction(funcOp, typeConverter, rewriter,
+ entryPointAttr, argABI);
+ if (!newFuncOp) {
+ return matchFailure();
+ }
+ newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
+ rewriter.getContext()));
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleOp with gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleConversion::matchAndRewrite(
+ ModuleOp moduleOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!moduleOp.getAttrOfType<UnitAttr>(
+ gpu::GPUDialect::getKernelModuleAttrName())) {
+ return matchFailure();
+ }
+  // TODO: Generalize this to account for different extensions, capabilities,
+  // extended_instruction_sets, other addressing models and memory models.
+ auto spvModule = rewriter.create<spirv::ModuleOp>(
+ moduleOp.getLoc(), spirv::AddressingModel::Logical,
+ spirv::MemoryModel::GLSL450, spirv::Capability::Shader,
+ spirv::Extension::SPV_KHR_storage_buffer_storage_class);
+ // Move the region from the module op into the SPIR-V module.
+ Region &spvModuleRegion = spvModule.getOperation()->getRegion(0);
+ rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
+ spvModuleRegion.begin());
+ // The spv.module build method adds a block with a terminator. Remove that
+ // block. The terminator of the module op in the remaining block will be
+ // legalized later.
+ spvModuleRegion.back().erase();
+ rewriter.eraseOp(moduleOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleTerminatorOp for gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
+ ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU return inside kernel functions to SPIR-V return.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult GPUReturnOpConversion::matchAndRewrite(
+ gpu::ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!operands.empty())
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU To SPIRV Patterns.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns,
+ ArrayRef<int64_t> workGroupSize) {
+ patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
+ patterns.insert<
+ GPUReturnOpConversion, ForOpConversion, KernelModuleConversion,
+ KernelModuleTerminatorConversion,
+ LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
+ LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
+ LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
+ LaunchConfigConversion<gpu::ThreadIdOp,
+ spirv::BuiltIn::LocalInvocationId>>(context,
+ typeConverter);
+}
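
The control flow that ForOpConversion builds inside `spirv::LoopOp` mirrors an ordinary counted loop: the entry block branches to a header that compares the induction variable against the upper bound with `spirv::SLessThanOp`, the inlined body falls through to the continue block, which adds the step and branches back to the header, and the merge block is the exit. A plain C++ model of that structure, for intuition only:

// Structural model of the spirv.loop generated for
//   loop.for %iv = %lb to %ub step %step { body(%iv) }
// The comments name the blocks created by ForOpConversion.
static void forOpAsSpirvLoop(long lowerBound, long upperBound, long step,
                             void (*body)(long)) {
  long iv = lowerBound;       // entry block branches to the header with %lb
  while (true) {
    if (!(iv < upperBound))   // header: spirv::SLessThanOp + BranchConditionalOp
      break;                  // ...the false edge goes to the merge (exit) block
    body(iv);                 // inlined loop body
    iv += step;               // continue block: spirv::IAddOp, branch to header
  }
}
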
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
new file mode 100644
index 00000000000..68392c36765
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -0,0 +1,96 @@
+//===- ConvertGPUToSPIRVPass.cpp - GPU to SPIR-V dialect lowering passes --===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a kernel function in the GPU Dialect
+// into a spv.module operation
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
+/// that have the "gpu.kernel" attribute, i.e. those functions that are
+/// referenced in gpu::LaunchKernelOp operations. For each such function:
+///
+/// 1) Create a spirv::ModuleOp, and clone the function into the new
+/// spirv::ModuleOp (the original function is still needed by the
+/// gpu::LaunchKernelOp, so it cannot be replaced).
+///
+/// 2) Lower the body of the spirv::ModuleOp.
+class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+public:
+ GPUToSPIRVPass() = default;
+ GPUToSPIRVPass(const GPUToSPIRVPass &) {}
+ GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+ this->workGroupSize = workGroupSize;
+ }
+
+ void runOnModule() override;
+
+private:
+ /// Command line option to specify the workgroup size.
+ ListOption<int64_t> workGroupSize{
+ *this, "workgroup-size",
+ llvm::cl::desc(
+ "Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
+ "by z dimension of the dispatch (others will be ignored)"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+} // namespace
+
+void GPUToSPIRVPass::runOnModule() {
+ auto context = &getContext();
+ auto module = getModule();
+
+ SmallVector<Operation *, 1> kernelModules;
+ OpBuilder builder(context);
+ module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
+ if (moduleOp.getAttrOfType<UnitAttr>(
+ gpu::GPUDialect::getKernelModuleAttrName())) {
+ // For each kernel module (should be only 1 for now, but that is not a
+ // requirement here), clone the module for conversion because the
+ // gpu.launch function still needs the kernel module.
+ builder.setInsertionPoint(moduleOp.getOperation());
+ kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
+ }
+ });
+
+ SPIRVTypeConverter typeConverter;
+ OwningRewritePatternList patterns;
+ populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
+ populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+
+ ConversionTarget target(*context);
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
+
+ if (failed(applyFullConversion(kernelModules, target, patterns,
+ &typeConverter))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+ return std::make_unique<GPUToSPIRVPass>(workGroupSize);
+}
+
+static PassRegistration<GPUToSPIRVPass>
+ pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
diff --git a/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt b/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000..9d2b5dac202
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRLinalgToLLVM
+ LinalgToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToLLVM
+)
+set(LIBS
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+ )
+
+add_dependencies(MLIRLinalgToLLVM ${LIBS})
+target_link_libraries(MLIRLinalgToLLVM ${LIBS})
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
new file mode 100644
index 00000000000..2a034fd15c5
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -0,0 +1,549 @@
+//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::LLVM;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using add = ValueBuilder<mlir::LLVM::AddOp>;
+using addi = ValueBuilder<mlir::AddIOp>;
+using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
+using cmpi = ValueBuilder<mlir::CmpIOp>;
+using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
+using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
+using gep = ValueBuilder<mlir::LLVM::GEPOp>;
+using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
+using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
+using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
+using llvm_load = ValueBuilder<LLVM::LoadOp>;
+using llvm_store = OperationBuilder<LLVM::StoreOp>;
+using llvm_select = ValueBuilder<LLVM::SelectOp>;
+using mul = ValueBuilder<mlir::LLVM::MulOp>;
+using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
+using sub = ValueBuilder<mlir::LLVM::SubOp>;
+using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>;
+using urem = ValueBuilder<mlir::LLVM::URemOp>;
+using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
+using llvm_return = OperationBuilder<LLVM::ReturnOp>;
+
+template <typename T>
+static LLVMType getPtrToElementType(T containerType,
+ LLVMTypeConverter &lowering) {
+ return lowering.convertType(containerType.getElementType())
+ .template cast<LLVMType>()
+ .getPointerTo();
+}
+
+// Convert the given type to the LLVM IR Dialect type. The following
+// conversions are supported:
+// - an Index type is converted into an LLVM integer type with pointer
+// bitwidth (analogous to intptr_t in C);
+// - an Integer type is converted into an LLVM integer type of the same width;
+//   - an F32 type is converted into an LLVM float type;
+// - a Buffer, Range or View is converted into an LLVM structure type
+// containing the respective dynamic values.
+static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
+ auto *context = t.getContext();
+ auto int64Ty = lowering.convertType(IntegerType::get(64, context))
+ .cast<LLVM::LLVMType>();
+
+ // Range descriptor contains the range bounds and the step as 64-bit integers.
+ //
+ // struct {
+ // int64_t min;
+ // int64_t max;
+ // int64_t step;
+ // };
+ if (t.isa<RangeType>())
+ return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
+
+ return Type();
+}
+
+namespace {
+/// EDSC-compatible wrapper for MemRefDescriptor.
+class BaseViewConversionHelper {
+public:
+ BaseViewConversionHelper(Type type)
+ : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
+
+ BaseViewConversionHelper(Value v) : d(v) {}
+
+ /// Wrappers around MemRefDescriptor that use EDSC builder and location.
+ Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
+ void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
+ Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+ void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
+ Value offset() { return d.offset(rewriter(), loc()); }
+ void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
+ Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
+ void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+ Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
+ void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+
+ operator Value() { return d; }
+
+private:
+ OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
+ Location loc() { return ScopedContext::getLocation(); }
+
+ MemRefDescriptor d;
+};
+} // namespace
+
+// RangeOp creates a new range descriptor.
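+// For illustration only (a hand-written sketch, not verbatim compiler
+// output), lowering "%r = linalg.range %min : %max : %step" produces roughly:
+//   %0 = llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
+//   %1 = llvm.insertvalue %min, %0[0] : !llvm<"{ i64, i64, i64 }">
+//   %2 = llvm.insertvalue %max, %1[1] : !llvm<"{ i64, i64, i64 }">
+//   %3 = llvm.insertvalue %step, %2[2] : !llvm<"{ i64, i64, i64 }">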
+class RangeOpConversion : public LLVMOpLowering {
+public:
+ explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto rangeOp = cast<RangeOp>(op);
+ auto rangeDescriptorTy =
+ convertLinalgType(rangeOp.getResult()->getType(), lowering);
+
+ edsc::ScopedContext context(rewriter, op->getLoc());
+
+ // Fill in an aggregate value of the descriptor.
+ RangeOpOperandAdaptor adaptor(operands);
+ Value desc = llvm_undef(rangeDescriptorTy);
+ desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
+ desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
+ desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
+/// Conversion pattern that transforms a linalg.slice op into:
+/// 1. An "undef" value for the destination memref descriptor.
+/// 2. A copy of the allocated and aligned pointers from the source descriptor.
+/// 3. Updates to the descriptor to introduce the offset, sizes and strides
+///    corresponding to the region of memory within the bounds of the parent
+///    view.
+/// The linalg.slice op is replaced by the resulting descriptor value.
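+///
+/// Illustrative arithmetic (pseudo-notation, not taken from the sources) for a
+/// range indexing {%min, %max, %step} applied to dimension d of the base view:
+///   offset' = offset + %min * stride[d]
+///   size'   = max(min(%max, size[d]) - %min, 0)
+///   stride' = stride[d] * %step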
+class SliceOpConversion : public LLVMOpLowering {
+public:
+ explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ SliceOpOperandAdaptor adaptor(operands);
+ BaseViewConversionHelper baseDesc(adaptor.view());
+
+ auto sliceOp = cast<SliceOp>(op);
+ auto memRefType = sliceOp.getBaseViewType();
+ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
+ .cast<LLVM::LLVMType>();
+
+ BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
+
+ // TODO(ntv): extract sizes and emit asserts.
+ SmallVector<Value, 4> strides(memRefType.getRank());
+ for (int i = 0, e = memRefType.getRank(); i < e; ++i)
+ strides[i] = baseDesc.stride(i);
+
+ auto pos = [&rewriter](ArrayRef<int64_t> values) {
+ return rewriter.getI64ArrayAttr(values);
+ };
+
+ // Compute base offset.
+ Value baseOffset = baseDesc.offset();
+ for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
+ Value indexing = adaptor.indexings()[i];
+ Value min = indexing;
+ if (sliceOp.indexing(i)->getType().isa<RangeType>())
+ min = extractvalue(int64Ty, indexing, pos(0));
+ baseOffset = add(baseOffset, mul(min, strides[i]));
+ }
+
+ // Insert the base and aligned pointers.
+ desc.setAllocatedPtr(baseDesc.allocatedPtr());
+ desc.setAlignedPtr(baseDesc.alignedPtr());
+
+ // Insert base offset.
+ desc.setOffset(baseOffset);
+
+ // Corner case, no sizes or strides: early return the descriptor.
+ if (sliceOp.getViewType().getRank() == 0)
+ return rewriter.replaceOp(op, {desc}), matchSuccess();
+
+ Value zero =
+ constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+ // Compute and insert view sizes (max - min along the range) and strides.
+ // Skip the non-range operands as they will be projected away from the view.
+ int numNewDims = 0;
+ for (auto en : llvm::enumerate(sliceOp.indexings())) {
+ Value indexing = en.value();
+ if (indexing->getType().isa<RangeType>()) {
+ int rank = en.index();
+ Value rangeDescriptor = adaptor.indexings()[rank];
+ Value min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+ Value max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+ Value step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value baseSize = baseDesc.size(rank);
+
+ // Bound upper by base view upper bound.
+ max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
+ baseSize);
+ Value size = sub(max, min);
+ // Bound lower by zero.
+ size =
+ llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
+ Value stride = mul(strides[rank], step);
+ desc.setSize(numNewDims, size);
+ desc.setStride(numNewDims, stride);
+ ++numNewDims;
+ }
+ }
+
+ rewriter.replaceOp(op, {desc});
+ return matchSuccess();
+ }
+};
+
+/// Conversion pattern that transforms a linalg.transpose op into:
+/// 1. An "undef" value for the destination memref descriptor.
+/// 2. A copy of the allocated and aligned pointers and of the offset from the
+///    source descriptor.
+/// 3. Updates to the descriptor that permute the sizes and strides according
+///    to the transposition map.
+/// The linalg.transpose op is replaced by the resulting descriptor value.
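+///
+/// For example (illustrative only), with the permutation (i, j) -> (j, i) the
+/// new descriptor receives size'[0] = size[1] and size'[1] = size[0]; strides
+/// are permuted the same way, while pointers and offset are copied unchanged.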
+class TransposeOpConversion : public LLVMOpLowering {
+public:
+ explicit TransposeOpConversion(MLIRContext *context,
+ LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Initialize the common boilerplate and alloca at the top of the FuncOp.
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ TransposeOpOperandAdaptor adaptor(operands);
+ BaseViewConversionHelper baseDesc(adaptor.view());
+
+ auto transposeOp = cast<TransposeOp>(op);
+ // No permutation, early exit.
+ if (transposeOp.permutation().isIdentity())
+ return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
+
+ BaseViewConversionHelper desc(
+ lowering.convertType(transposeOp.getViewType()));
+
+ // Copy the base and aligned pointers from the old descriptor to the new
+ // one.
+ desc.setAllocatedPtr(baseDesc.allocatedPtr());
+ desc.setAlignedPtr(baseDesc.alignedPtr());
+
+ // Copy the offset pointer from the old descriptor to the new one.
+ desc.setOffset(baseDesc.offset());
+
+ // Iterate over the dimensions and apply size/stride permutation.
+ for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
+ int sourcePos = en.index();
+ int targetPos = en.value().cast<AffineDimExpr>().getPosition();
+ desc.setSize(targetPos, baseDesc.size(sourcePos));
+ desc.setStride(targetPos, baseDesc.stride(sourcePos));
+ }
+
+ rewriter.replaceOp(op, {desc});
+ return matchSuccess();
+ }
+};
+
+// YieldOp produces an LLVM::ReturnOp.
+class YieldOpConversion : public LLVMOpLowering {
+public:
+ explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
+ return matchSuccess();
+ }
+};
+
+template <typename LinalgOp>
+static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
+ return SmallVector<Type, 4>{op->getOperandTypes()};
+}
+
+template <>
+SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
+ auto ctx = op->getContext();
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+
+ SmallVector<Type, 4> result;
+ result.reserve(numLoops + op->getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i) {
+ result.push_back(IndexType::get(ctx));
+ }
+ for (auto type : op->getOperandTypes()) {
+ result.push_back(type);
+ }
+ return result;
+}
+
+// Get a SymbolRefAttr containing the library function name for the LinalgOp.
+// If the library function does not exist, insert a declaration.
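+// For illustration (the name is hypothetical), an op whose library call name
+// is "linalg_dot" results in a declaration of the form
+//   func @linalg_dot(%arg0: ..., %arg1: ..., %arg2: ...)
+// taking the op's operand types and returning no results.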
+template <typename LinalgOp>
+static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
+ PatternRewriter &rewriter) {
+ auto linalgOp = cast<LinalgOp>(op);
+ auto fnName = linalgOp.getLibraryCallName();
+ if (fnName.empty()) {
+ op->emitWarning("No library call defined for: ") << *op;
+ return {};
+ }
+
+  // fnName is a dynamic std::string; unique it via a SymbolRefAttr.
+ FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+ auto module = op->getParentOfType<ModuleOp>();
+ if (module.lookupSymbol(fnName)) {
+ return fnNameAttr;
+ }
+
+ SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
+ assert(op->getNumResults() == 0 &&
+ "Library call for linalg operation can be generated only for ops that "
+ "have void return types");
+ auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Insert before module terminator.
+ rewriter.setInsertionPoint(module.getBody(),
+ std::prev(module.getBody()->end()));
+ rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
+ ArrayRef<NamedAttribute>{});
+ return fnNameAttr;
+}
+
+Type LinalgTypeConverter::convertType(Type t) {
+ if (auto result = LLVMTypeConverter::convertType(t))
+ return result;
+ return convertLinalgType(t, *this);
+}
+
+// LinalgOpConversion<LinalgOp> creates a new call to the
+// `LinalgOp::getLibraryCallName()` function.
+// The implementation of the function can be either in the same module or in an
+// externally linked library.
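+// For illustration (names are hypothetical), a linalg op with library call
+// name "linalg_foo" and operands %a, %b is rewritten into
+//   call @linalg_foo(%a, %b) : (...) -> ()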
+template <typename LinalgOp>
+class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
+public:
+ using OpRewritePattern<LinalgOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
+ if (!libraryCallName)
+ return this->matchFailure();
+
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
+ return this->matchSuccess();
+ }
+};
+
+/// Conversion pattern specialization for CopyOp. This kicks in when both input
+/// and output permutations are left unspecified or are the identity.
+template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ return matchFailure();
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ return matchFailure();
+
+ auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
+ if (!libraryCallName)
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
+ return matchSuccess();
+ }
+};
+
+/// Conversion pattern specialization for IndexedGenericOp.
+template <>
+class LinalgOpConversion<IndexedGenericOp>
+ : public OpRewritePattern<IndexedGenericOp> {
+public:
+ using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(IndexedGenericOp op,
+ PatternRewriter &rewriter) const override {
+ auto libraryCallName =
+ getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
+ if (!libraryCallName)
+ return this->matchFailure();
+
+ // TODO(pifon, ntv): Use induction variables values instead of zeros, when
+ // IndexedGenericOp is tiled.
+ auto zero = rewriter.create<mlir::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+ SmallVector<Value, 4> operands;
+ operands.reserve(numLoops + op.getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i) {
+ operands.push_back(zero);
+ }
+ for (auto operand : op.getOperands()) {
+ operands.push_back(operand);
+ }
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
+ ArrayRef<Type>{}, operands);
+ return this->matchSuccess();
+ }
+};
+
+/// A non-conversion rewrite pattern that converts a CopyOp with permutations
+/// into a sequence of TransposeOp and a permutation-free CopyOp. This
+/// interplays with TransposeOpConversion and LinalgOpConversion<CopyOp> to
+/// create a path to the LLVM dialect.
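+///
+/// Illustrative sketch (pseudo-IR, syntax approximate): a copy with a
+/// non-identity input permutation P is rewritten into
+///   %t = linalg.transpose %in (P)
+///   linalg.copy(%t, %out)
+/// so that only the permutation-free copy needs a library call.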
+class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
+ Value in = op.input(), out = op.output();
+
+ // If either inputPerm or outputPerm are non-identities, insert transposes.
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
+ AffineMapAttr::get(*inputPerm));
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ out = rewriter.create<linalg::TransposeOp>(
+ op.getLoc(), out, AffineMapAttr::get(*outputPerm));
+
+ // If nothing was transposed, fail and let the conversion kick in.
+ if (in == op.input() && out == op.output())
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+ return matchSuccess();
+ }
+};
+
+/// Populate the given list with patterns that convert from Linalg to Standard.
+static void
+populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {
+ // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
+ // attribute values such as kernel striding and dilation.
+ patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
+ LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
+ LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
+ LinalgOpConversion<IndexedGenericOp>,
+ LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
+ ctx);
+}
+
+/// Populate the given list with patterns that convert from Linalg to LLVM.
+void mlir::populateLinalgToLLVMConversionPatterns(
+ LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
+ MLIRContext *ctx) {
+ patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
+ YieldOpConversion>(ctx, converter);
+}
+
+namespace {
+struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+void ConvertLinalgToLLVMPass::runOnModule() {
+ auto module = getModule();
+
+ // Convert to the LLVM IR dialect using the converter defined above.
+ OwningRewritePatternList patterns;
+ LinalgTypeConverter converter(&getContext());
+ populateAffineToStdConversionPatterns(patterns, &getContext());
+ populateLoopToStdConversionPatterns(patterns, &getContext());
+ populateStdToLLVMConversionPatterns(converter, patterns);
+ populateVectorToLLVMConversionPatterns(converter, patterns);
+ populateLinalgToStandardConversionPatterns(patterns, &getContext());
+ populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+ target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+ if (failed(applyFullConversion(module, target, patterns, &converter)))
+ signalPassFailure();
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::linalg::createConvertLinalgToLLVMPass() {
+ return std::make_unique<ConvertLinalgToLLVMPass>();
+}
+
+static PassRegistration<ConvertLinalgToLLVMPass> pass(
+ "convert-linalg-to-llvm",
+ "Convert the operations from the linalg dialect into the LLVM dialect");
diff --git a/mlir/lib/Conversion/LoopToStandard/CMakeLists.txt b/mlir/lib/Conversion/LoopToStandard/CMakeLists.txt
new file mode 100644
index 00000000000..8f05dbd0b63
--- /dev/null
+++ b/mlir/lib/Conversion/LoopToStandard/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_llvm_library(MLIRLoopToStandard
+ ConvertLoopToStandard.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LoopToStandard
+)
+add_dependencies(
+ MLIRLoopToStandard
+
+ MLIRLoopOps
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+)
+target_link_libraries(
+ MLIRLoopToStandard
+
+ MLIRLoopOps
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+)
diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
new file mode 100644
index 00000000000..b257e9b482b
--- /dev/null
+++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -0,0 +1,269 @@
+//===- ConvertLoopToStandard.cpp - ControlFlow to CFG conversion ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert loop.for, loop.if and loop.terminator
+// ops into standard CFG ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+namespace {
+
+struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
+ void runOnOperation() override;
+};
+
+// Create a CFG subgraph for the loop around its body blocks (if the body
+// contained other loops, they have been already lowered to a flow of blocks).
+// Maintain the invariants that a CFG subgraph created for any loop has a single
+// entry and a single exit, and that the entry/exit blocks are respectively
+// first/last blocks in the parent region. The original loop operation is
+// replaced by the initialization operations that set up the initial value of
+// the loop induction variable (%iv) and compute the loop bounds that are loop-
+// invariant for affine loops. The operations following the original loop.for
+// are split out into a separate continuation (exit) block. A condition block is
+// created before the continuation block. It checks the exit condition of the
+// loop and branches either to the continuation block, or to the first block of
+// the body. Induction variable modification is appended to the last block of
+// the body (which is the exit block from the body subgraph thanks to the
+// invariant we maintain) along with a branch that loops back to the condition
+// block.
+//
+// +---------------------------------+
+// | <code before the ForOp> |
+// | <compute initial %iv value> |
+// | br cond(%iv) |
+// +---------------------------------+
+// |
+// -------| |
+// | v v
+// | +--------------------------------+
+// | | cond(%iv): |
+// | | <compare %iv to upper bound> |
+// | | cond_br %r, body, end |
+// | +--------------------------------+
+// | | |
+// | | -------------|
+// | v |
+// | +--------------------------------+ |
+// | | body-first: | |
+// | | <body contents> | |
+// | +--------------------------------+ |
+// | | |
+// | ... |
+// | | |
+// | +--------------------------------+ |
+// | | body-last: | |
+// | | <body contents> | |
+// | | %new_iv =<add step to %iv> | |
+// | | br cond(%new_iv) | |
+// | +--------------------------------+ |
+// | | |
+// |----------- |--------------------
+// v
+// +--------------------------------+
+// | end: |
+// | <code after the ForOp> |
+// +--------------------------------+
+//
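+// For illustration only (hand-written pseudo-IR, syntax approximate), a loop
+//   loop.for %i = %lb to %ub step %s { <body> }
+// becomes roughly:
+//   br ^cond(%lb)
+//   ^cond(%iv): %r = cmpi "slt", %iv, %ub
+//               cond_br %r, ^body, ^end
+//   ^body: <body>
+//          %next = addi %iv, %s
+//          br ^cond(%next)
+//   ^end: ...
+//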
+struct ForLowering : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const override;
+};
+
+// Create a CFG subgraph for the loop.if operation (including its "then" and
+// optional "else" operation blocks). We maintain the invariants that the
+// subgraph has a single entry and a single exit point, and that the entry/exit
+// blocks are respectively the first/last block of the enclosing region. The
+// operations following the loop.if are split into a continuation (subgraph
+// exit) block. The block that originally contained the loop.if is terminated
+// with a conditional branch that jumps either to the first block of the "then"
+// region, or to the first block of the "else" region. If the latter is absent,
+// it branches to the continuation block instead. The last blocks of the "then"
+// and "else" regions (which are known to be exit blocks thanks to the
+// invariant we maintain) are terminated with branches to the continuation
+// block.
+//
+// +--------------------------------+
+// | <code before the IfOp> |
+// | cond_br %cond, %then, %else |
+// +--------------------------------+
+// | |
+// | --------------|
+// v |
+// +--------------------------------+ |
+// | then: | |
+// | <then contents> | |
+// | br continue | |
+// +--------------------------------+ |
+// | |
+// |---------- |-------------
+// | V
+// | +--------------------------------+
+// | | else: |
+// | | <else contents> |
+// | | br continue |
+// | +--------------------------------+
+// | |
+// ------| |
+// v v
+// +--------------------------------+
+// | continue: |
+// | <code after the IfOp> |
+// +--------------------------------+
+//
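+// For illustration only (hand-written pseudo-IR, syntax approximate), an op
+//   loop.if %cond { <then> } else { <else> }
+// becomes roughly:
+//   cond_br %cond, ^then, ^else
+//   ^then: <then>   br ^continue
+//   ^else: <else>   br ^continue
+//   ^continue: ...
+//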
+struct IfLowering : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(IfOp ifOp,
+ PatternRewriter &rewriter) const override;
+};
+
+struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
+ using OpRewritePattern<TerminatorOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(TerminatorOp op,
+ PatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+} // namespace
+
+PatternMatchResult
+ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
+ Location loc = forOp.getLoc();
+
+ // Start by splitting the block containing the 'loop.for' into two parts.
+ // The part before will get the init code, the part after will be the end
+ // point.
+ auto *initBlock = rewriter.getInsertionBlock();
+ auto initPosition = rewriter.getInsertionPoint();
+ auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
+
+ // Use the first block of the loop body as the condition block since it is
+ // the block that has the induction variable as its argument. Split out
+ // all operations from the first block into a new block. Move all body
+ // blocks from the loop body region to the region containing the loop.
+ auto *conditionBlock = &forOp.region().front();
+ auto *firstBodyBlock =
+ rewriter.splitBlock(conditionBlock, conditionBlock->begin());
+ auto *lastBodyBlock = &forOp.region().back();
+ rewriter.inlineRegionBefore(forOp.region(), endBlock);
+ auto iv = conditionBlock->getArgument(0);
+
+ // Append the induction variable stepping logic to the last body block and
+ // branch back to the condition block. Construct an expression f :
+ // (x -> x+step) and apply this expression to the induction variable.
+ rewriter.setInsertionPointToEnd(lastBodyBlock);
+ auto step = forOp.step();
+ auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
+ if (!stepped)
+ return matchFailure();
+ rewriter.create<BranchOp>(loc, conditionBlock, stepped);
+
+ // Compute loop bounds before branching to the condition.
+ rewriter.setInsertionPointToEnd(initBlock);
+ Value lowerBound = forOp.lowerBound();
+ Value upperBound = forOp.upperBound();
+ if (!lowerBound || !upperBound)
+ return matchFailure();
+ rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
+
+ // With the body block done, we can fill in the condition block.
+ rewriter.setInsertionPointToEnd(conditionBlock);
+ auto comparison =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
+
+ rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
+ ArrayRef<Value>(), endBlock, ArrayRef<Value>());
+ // Ok, we're done!
+ rewriter.eraseOp(forOp);
+ return matchSuccess();
+}
+
+PatternMatchResult
+IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
+ auto loc = ifOp.getLoc();
+
+ // Start by splitting the block containing the 'loop.if' into two parts.
+ // The part before will contain the condition, the part after will be the
+ // continuation point.
+ auto *condBlock = rewriter.getInsertionBlock();
+ auto opPosition = rewriter.getInsertionPoint();
+ auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
+
+ // Move blocks from the "then" region to the region containing 'loop.if',
+ // place it before the continuation block, and branch to it.
+ auto &thenRegion = ifOp.thenRegion();
+ auto *thenBlock = &thenRegion.front();
+ rewriter.setInsertionPointToEnd(&thenRegion.back());
+ rewriter.create<BranchOp>(loc, continueBlock);
+ rewriter.inlineRegionBefore(thenRegion, continueBlock);
+
+ // Move blocks from the "else" region (if present) to the region containing
+ // 'loop.if', place it before the continuation block and branch to it. It
+ // will be placed after the "then" regions.
+ auto *elseBlock = continueBlock;
+ auto &elseRegion = ifOp.elseRegion();
+ if (!elseRegion.empty()) {
+ elseBlock = &elseRegion.front();
+ rewriter.setInsertionPointToEnd(&elseRegion.back());
+ rewriter.create<BranchOp>(loc, continueBlock);
+ rewriter.inlineRegionBefore(elseRegion, continueBlock);
+ }
+
+ rewriter.setInsertionPointToEnd(condBlock);
+ rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
+ /*trueArgs=*/ArrayRef<Value>(), elseBlock,
+ /*falseArgs=*/ArrayRef<Value>());
+
+ // Ok, we're done!
+ rewriter.eraseOp(ifOp);
+ return matchSuccess();
+}
+
+void mlir::populateLoopToStdConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
+}
+
+void LoopToStandardPass::runOnOperation() {
+ OwningRewritePatternList patterns;
+ populateLoopToStdConversionPatterns(patterns, &getContext());
+ ConversionTarget target(getContext());
+ target.addLegalDialect<StandardOpsDialect>();
+ if (failed(applyPartialConversion(getOperation(), target, patterns)))
+ signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
+ return std::make_unique<LoopToStandardPass>();
+}
+
+static PassRegistration<LoopToStandardPass>
+ pass("convert-loop-to-std", "Convert Loop dialect to Standard dialect, "
+ "replacing structured control flow with a CFG");
diff --git a/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt b/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt
new file mode 100644
index 00000000000..2dacc800cb2
--- /dev/null
+++ b/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LIBS
+ MLIRAffineOps
+ MLIRGPU
+ MLIRIR
+ MLIRLinalg
+ MLIRPass
+ MLIRStandardOps
+ MLIRSupport
+ MLIRTransforms
+ LLVMSupport
+)
+
+add_llvm_library(MLIRLoopsToGPU
+ LoopsToGPU.cpp
+ LoopsToGPUPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LoopsToGPU
+)
+add_dependencies(MLIRLoopsToGPU ${LIBS})
+target_link_libraries(MLIRLoopsToGPU ${LIBS})
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
new file mode 100644
index 00000000000..e500d10983c
--- /dev/null
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -0,0 +1,528 @@
+//===- LoopsToGPU.cpp - Convert an affine loop nest to a GPU kernel -------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This implements a straightforward conversion of a loop nest into a GPU
+// kernel. The caller is expected to guarantee that the conversion is correct
+// or to further transform the kernel to ensure correctness.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "loops-to-gpu"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+using llvm::seq;
+
+// Extract an indexed value from KernelDim3.
+static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
+ switch (pos) {
+ case 0:
+ return dim3.x;
+ case 1:
+ return dim3.y;
+ case 2:
+ return dim3.z;
+ default:
+ llvm_unreachable("dim3 position out of bounds");
+ }
+ return nullptr;
+}
+
+// Get the lower bound-related operands of a loop operation.
+static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) {
+ return forOp.getLowerBoundOperands();
+}
+static SmallVector<Value, 1> getLowerBoundOperands(ForOp forOp) {
+ SmallVector<Value, 1> bounds(1, forOp.lowerBound());
+ return bounds;
+}
+
+// Get the upper bound-related operands of a loop operation.
+static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) {
+ return forOp.getUpperBoundOperands();
+}
+static SmallVector<Value, 1> getUpperBoundOperands(ForOp forOp) {
+ SmallVector<Value, 1> bounds(1, forOp.upperBound());
+ return bounds;
+}
+
+// Get a Value that corresponds to the loop step. If the step is an attribute,
+// materialize a corresponding constant using builder.
+static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) {
+ return builder.create<ConstantIndexOp>(forOp.getLoc(), forOp.getStep());
+}
+static Value getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); }
+
+// Get a Value for the loop lower bound. If the value requires computation,
+// materialize the instructions using builder.
+static Value getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) {
+ return lowerAffineLowerBound(forOp, builder);
+}
+static Value getOrEmitLowerBound(ForOp forOp, OpBuilder &) {
+ return forOp.lowerBound();
+}
+
+// Get a Value for the loop upper bound. If the value requires computation,
+// materialize the instructions using builder.
+static Value getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) {
+ return lowerAffineUpperBound(forOp, builder);
+}
+static Value getOrEmitUpperBound(ForOp forOp, OpBuilder &) {
+ return forOp.upperBound();
+}
+
+// Check the structure of the loop nest:
+// - there are enough loops to map to numDims;
+// - the loops are perfectly nested;
+// - the loop bounds can be computed above the outermost loop.
+// This roughly corresponds to the "matcher" part of the pattern-based
+// rewriting infrastructure.
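+//
+// For example (illustrative pseudo-IR), a nest mappable to two dimensions
+// looks like:
+//   loop.for %i = ... {       // bounds computable above this loop
+//     loop.for %j = ... {     // the only op in the outer body (+ terminator)
+//       <arbitrary body>
+//     }
+//   }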
+template <typename OpTy>
+LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) {
+ Region &limit = forOp.region();
+ for (unsigned i = 0, e = numDims; i < e; ++i) {
+ Operation *nested = &forOp.getBody()->front();
+ if (!areValuesDefinedAbove(getLowerBoundOperands(forOp), limit) ||
+ !areValuesDefinedAbove(getUpperBoundOperands(forOp), limit))
+ return forOp.emitError(
+ "loops with bounds depending on other mapped loops "
+ "are not supported");
+
+ // The innermost loop can have an arbitrary body, skip the perfect nesting
+ // check for it.
+ if (i == e - 1)
+ break;
+
+ auto begin = forOp.getBody()->begin(), end = forOp.getBody()->end();
+ if (forOp.getBody()->empty() || std::next(begin, 2) != end)
+ return forOp.emitError("expected perfectly nested loops in the body");
+
+ if (!(forOp = dyn_cast<OpTy>(nested)))
+ return nested->emitError("expected a nested loop");
+ }
+ return success();
+}
+
+template <typename OpTy>
+LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
+ unsigned numThreadDims) {
+ if (numBlockDims < 1 || numThreadDims < 1) {
+ LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+ return success();
+ }
+
+ OpBuilder builder(forOp.getOperation());
+ if (numBlockDims > 3) {
+ return forOp.emitError("cannot map to more than 3 block dimensions");
+ }
+ if (numThreadDims > 3) {
+ return forOp.emitError("cannot map to more than 3 thread dimensions");
+ }
+ return checkLoopNestMappableImpl(forOp, numBlockDims + numThreadDims);
+}
+
+template <typename OpTy>
+LogicalResult checkLoopOpMappable(OpTy forOp, unsigned numBlockDims,
+ unsigned numThreadDims) {
+ if (numBlockDims < 1 || numThreadDims < 1) {
+ LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+ return success();
+ }
+
+ if (numBlockDims > 3) {
+ return forOp.emitError("cannot map to more than 3 block dimensions");
+ }
+ if (numThreadDims > 3) {
+ return forOp.emitError("cannot map to more than 3 thread dimensions");
+ }
+ if (numBlockDims != numThreadDims) {
+ // TODO(ravishankarm) : This can probably be relaxed by having a one-trip
+    // loop for the missing dimension, but there is no reason to handle this
+ // case for now.
+ return forOp.emitError(
+ "mismatch in block dimensions and thread dimensions");
+ }
+
+ // Check that the forOp contains perfectly nested loops for numBlockDims
+ if (failed(checkLoopNestMappableImpl(forOp, numBlockDims))) {
+ return failure();
+ }
+
+ // Get to the innermost loop.
+ for (auto i : seq<unsigned>(0, numBlockDims - 1)) {
+ forOp = cast<OpTy>(&forOp.getBody()->front());
+ (void)i;
+ }
+
+ // The forOp now points to the body of the innermost loop mapped to blocks.
+ for (Operation &op : *forOp.getBody()) {
+ // If the operation is a loop, check that it is mappable to workItems.
+ if (auto innerLoop = dyn_cast<OpTy>(&op)) {
+ if (failed(checkLoopNestMappableImpl(innerLoop, numThreadDims))) {
+ return failure();
+ }
+ continue;
+ }
+ // TODO(ravishankarm) : If it is not a loop op, it is assumed that the
+ // statement is executed by all threads. It might be a collective operation,
+    // or some side-effect-free operation. We have to decide on "allowable"
+ // statements and check for those here.
+ }
+ return success();
+}
+
+namespace {
+// Helper structure that holds common state of the loop to GPU kernel
+// conversion.
+struct LoopToGpuConverter {
+ template <typename OpTy>
+ Optional<OpTy> collectBounds(OpTy forOp, unsigned numLoops);
+
+ template <typename OpTy>
+ void createLaunch(OpTy rootForOp, OpTy innermostForOp, unsigned numBlockDims,
+ unsigned numThreadDims);
+
+ // Ranges of the loops mapped to blocks or threads.
+ SmallVector<Value, 6> dims;
+ // Lower bounds of the loops mapped to blocks or threads.
+ SmallVector<Value, 6> lbs;
+ // Induction variables of the loops mapped to blocks or threads.
+ SmallVector<Value, 6> ivs;
+ // Steps of the loops mapped to blocks or threads.
+ SmallVector<Value, 6> steps;
+};
+} // namespace
+
+// Return true if the value is obviously a constant "one".
+static bool isConstantOne(Value value) {
+ if (auto def = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
+ return def.getValue() == 1;
+ return false;
+}
+
+// Collect ranges, bounds, steps and induction variables in preparation for
+// mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel.
+// This may fail if the IR for computing loop bounds cannot be constructed, for
+// example if an affine loop uses semi-affine maps. Return the last loop to be
+// mapped on success, llvm::None on failure.
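+// As an illustrative formula, a loop with bounds (lb, ub) and step s
+// contributes a hardware dimension of size (ub - lb) / s; the division is
+// omitted when s is the constant 1, where it would be a no-op anyway.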
+template <typename OpTy>
+Optional<OpTy> LoopToGpuConverter::collectBounds(OpTy forOp,
+ unsigned numLoops) {
+ OpBuilder builder(forOp.getOperation());
+ dims.reserve(numLoops);
+ lbs.reserve(numLoops);
+ ivs.reserve(numLoops);
+ steps.reserve(numLoops);
+ OpTy currentLoop = forOp;
+ for (unsigned i = 0; i < numLoops; ++i) {
+ Value lowerBound = getOrEmitLowerBound(currentLoop, builder);
+ Value upperBound = getOrEmitUpperBound(currentLoop, builder);
+ if (!lowerBound || !upperBound) {
+ return llvm::None;
+ }
+
+ Value range =
+ builder.create<SubIOp>(currentLoop.getLoc(), upperBound, lowerBound);
+ Value step = getOrCreateStep(currentLoop, builder);
+ if (!isConstantOne(step))
+ range = builder.create<SignedDivIOp>(currentLoop.getLoc(), range, step);
+ dims.push_back(range);
+
+ lbs.push_back(lowerBound);
+ ivs.push_back(currentLoop.getInductionVar());
+ steps.push_back(step);
+
+ if (i != numLoops - 1)
+ currentLoop = cast<OpTy>(&currentLoop.getBody()->front());
+ }
+ return currentLoop;
+}
+
+/// Given `nDims` perfectly nested loops rooted at `rootForOp`, convert them to
+/// be partitioned across workgroups or workitems. The values for the
+/// workgroup/workitem id along each dimension are passed in with `ids`. The
+/// numbers of workgroups/workitems along each dimension are passed in with
+/// `nids`. The innermost loop is mapped to the x-dimension, the next innermost
+/// to the y-dimension, and so on up to the z-dimension.
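+/// For example (illustrative), in a 3-D launch the outermost of the three
+/// mapped loops uses the z-dimension id and count while the innermost uses the
+/// x-dimension, matching the packing order produced by packIdAndNumId below.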
+template <typename OpTy>
+OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value> ids,
+ ArrayRef<Value> nids) {
+ auto nDims = ids.size();
+ assert(nDims == nids.size());
+ for (auto dim : llvm::seq<unsigned>(0, nDims)) {
+ // TODO(ravishankarm): Don't always need to generate a loop here. If nids >=
+    // number of iterations of the original loop, this becomes an if
+ // condition. Though that does rely on how the workgroup/workitem sizes are
+ // specified to begin with.
+ mapLoopToProcessorIds(rootForOp, ids[dim], nids[dim]);
+ if (dim != nDims - 1) {
+ rootForOp = cast<OpTy>(rootForOp.getBody()->front());
+ }
+ }
+ return rootForOp;
+}
+
+/// Utility method to convert the gpu::KernelDim3 object for representing id of
+/// each workgroup/workitem and number of workgroup/workitems along a dimension
+/// of the launch into a container.
+void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids,
+ unsigned nDims, SmallVectorImpl<Value> &ids,
+ SmallVectorImpl<Value> &nids) {
+ assert(nDims <= 3 && "invalid number of launch dimensions");
+ SmallVector<Value, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x};
+ SmallVector<Value, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x};
+ ids.clear();
+ ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end());
+ nids.clear();
+ nids.append(std::next(allNids.begin(), allNids.size() - nDims),
+ allNids.end());
+}
+
+/// Generate the body of the launch operation.
+template <typename OpTy>
+LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp,
+ gpu::LaunchOp launchOp, unsigned numBlockDims,
+ unsigned numThreadDims) {
+ OpBuilder::InsertionGuard bodyInsertionGuard(builder);
+ builder.setInsertionPointToEnd(&launchOp.body().front());
+ auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc());
+
+ rootForOp.getOperation()->moveBefore(returnOp);
+ SmallVector<Value, 3> workgroupID, numWorkGroups;
+ packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims,
+ workgroupID, numWorkGroups);
+
+ // Partition the loop for mapping to workgroups.
+ auto loopOp = createGPULaunchLoops(rootForOp, workgroupID, numWorkGroups);
+
+ // Iterate over the body of the loopOp and get the loops to partition for
+ // thread blocks.
+ SmallVector<OpTy, 1> threadRootForOps;
+ for (Operation &op : *loopOp.getBody()) {
+ if (auto threadRootForOp = dyn_cast<OpTy>(&op)) {
+ threadRootForOps.push_back(threadRootForOp);
+ }
+ }
+
+ SmallVector<Value, 3> workItemID, workGroupSize;
+ packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(),
+ numThreadDims, workItemID, workGroupSize);
+ for (auto &loopOp : threadRootForOps) {
+ builder.setInsertionPoint(loopOp);
+ createGPULaunchLoops(loopOp, workItemID, workGroupSize);
+ }
+ return success();
+}
+
+// Convert the computation rooted at `rootForOp` into a GPU kernel with the
+// given workgroup size and number of workgroups.
+template <typename OpTy>
+LogicalResult createLaunchFromOp(OpTy rootForOp, ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSizes) {
+ OpBuilder builder(rootForOp.getOperation());
+ if (numWorkGroups.size() > 3) {
+ return rootForOp.emitError("invalid ")
+ << numWorkGroups.size() << "-D workgroup specification";
+ }
+ auto loc = rootForOp.getLoc();
+ Value one = builder.create<ConstantOp>(
+ loc, builder.getIntegerAttr(builder.getIndexType(), 1));
+ SmallVector<Value, 3> numWorkGroups3D(3, one), workGroupSize3D(3, one);
+ for (auto numWorkGroup : enumerate(numWorkGroups)) {
+ numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value();
+ }
+ for (auto workGroupSize : enumerate(workGroupSizes)) {
+ workGroupSize3D[workGroupSize.index()] = workGroupSize.value();
+ }
+
+ // Get the values used within the region of the rootForOp but defined above
+ // it.
+ llvm::SetVector<Value> valuesToForwardSet;
+ getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(),
+ valuesToForwardSet);
+ // Also add the values used for the lb, ub, and step of the rootForOp.
+ valuesToForwardSet.insert(rootForOp.getOperands().begin(),
+ rootForOp.getOperands().end());
+ auto valuesToForward = valuesToForwardSet.takeVector();
+ auto launchOp = builder.create<gpu::LaunchOp>(
+ rootForOp.getLoc(), numWorkGroups3D[0], numWorkGroups3D[1],
+ numWorkGroups3D[2], workGroupSize3D[0], workGroupSize3D[1],
+ workGroupSize3D[2], valuesToForward);
+ if (failed(createLaunchBody(builder, rootForOp, launchOp,
+ numWorkGroups.size(), workGroupSizes.size()))) {
+ return failure();
+ }
+
+ // Replace values that are used within the region of the launchOp but are
+ // defined outside. They all are replaced with kernel arguments.
+ for (const auto &pair :
+ llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
+ Value from = std::get<0>(pair);
+ Value to = std::get<1>(pair);
+ replaceAllUsesInRegionWith(from, to, launchOp.body());
+ }
+ return success();
+}
+
+// Replace the loop nest rooted at "rootForOp" with a GPU launch operation.
+// This expects "innermostForOp" to point to the last loop to be transformed
+// to the kernel, and to have (numBlockDims + numThreadDims) perfectly nested
+// loops between "rootForOp" and "innermostForOp".
+// TODO(ravishankarm) : This method can be modified to use the
+// createLaunchFromOp method, since that is a strict generalization of this
+// method.
+template <typename OpTy>
+void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
+ unsigned numBlockDims,
+ unsigned numThreadDims) {
+ OpBuilder builder(rootForOp.getOperation());
+ // Prepare the grid and block sizes for the launch operation. If there is
+ // no loop mapped to a specific dimension, use constant "1" as its size.
+ Value constOne = (numBlockDims < 3 || numThreadDims < 3)
+ ? builder.create<ConstantIndexOp>(rootForOp.getLoc(), 1)
+ : nullptr;
+ Value gridSizeX = dims[0];
+ Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
+ Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne;
+ Value blockSizeX = dims[numBlockDims];
+ Value blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne;
+ Value blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne;
+
+ // Create a launch op and move the body region of the innermost loop to the
+ // launch op. Pass the values defined outside the outermost loop and used
+ // inside the innermost loop and loop lower bounds as kernel data arguments.
+ // Still assuming perfect nesting so there are no values other than induction
+ // variables that are defined in one loop and used in deeper loops.
+ llvm::SetVector<Value> valuesToForwardSet;
+ getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(),
+ valuesToForwardSet);
+ auto valuesToForward = valuesToForwardSet.takeVector();
+ auto originallyForwardedValues = valuesToForward.size();
+ valuesToForward.insert(valuesToForward.end(), lbs.begin(), lbs.end());
+ valuesToForward.insert(valuesToForward.end(), steps.begin(), steps.end());
+ auto launchOp = builder.create<gpu::LaunchOp>(
+ rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX,
+ blockSizeY, blockSizeZ, valuesToForward);
+ valuesToForward.resize(originallyForwardedValues);
+
+ // Replace the loop terminator (loops contain only a single block) with the
+ // gpu return and move the operations from the loop body block to the gpu
+ // launch body block. Do not move the entire block because of the difference
+ // in block arguments.
+ Operation &terminator = innermostForOp.getBody()->back();
+ Location terminatorLoc = terminator.getLoc();
+ terminator.erase();
+ builder.setInsertionPointToEnd(innermostForOp.getBody());
+ builder.create<gpu::ReturnOp>(terminatorLoc);
+ launchOp.body().front().getOperations().splice(
+ launchOp.body().front().begin(),
+ innermostForOp.getBody()->getOperations());
+
+ // Remap the loop iterators to use block/thread identifiers instead. Loops
+ // may iterate from LB with step S whereas GPU thread/block ids always iterate
+ // from 0 to N with step 1. Therefore, loop induction variables are replaced
+ // with (gpu-thread/block-id * S) + LB.
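+  // As an illustrative formula (using the names above), an induction variable
+  // iv mapped to hardware id "id", with lower bound lb and step s, has its
+  // uses rewritten to lb + id * s; the multiplication is elided when s is the
+  // constant 1.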
+ builder.setInsertionPointToStart(&launchOp.body().front());
+ auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(),
+ originallyForwardedValues);
+ auto stepArgumentIt = std::next(lbArgumentIt, lbs.size());
+ for (auto en : llvm::enumerate(ivs)) {
+ Value id =
+ en.index() < numBlockDims
+ ? getDim3Value(launchOp.getBlockIds(), en.index())
+ : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
+ Value step = steps[en.index()];
+ if (!isConstantOne(step))
+ id = builder.create<MulIOp>(rootForOp.getLoc(), step, id);
+
+ Value ivReplacement =
+ builder.create<AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
+ en.value()->replaceAllUsesWith(ivReplacement);
+ replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt,
+ launchOp.body());
+ std::advance(lbArgumentIt, 1);
+ std::advance(stepArgumentIt, 1);
+ }
+
+ // Remap the values defined outside the body to use kernel arguments instead.
+ // The list of kernel arguments also contains the lower bounds for loops at
+ // trailing positions, make sure we don't touch those.
+ for (const auto &pair :
+ llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
+ Value from = std::get<0>(pair);
+ Value to = std::get<1>(pair);
+ replaceAllUsesInRegionWith(from, to, launchOp.body());
+ }
+
+ // We are done and can erase the original outermost loop.
+ rootForOp.erase();
+}
+
+// Generic loop to GPU kernel conversion function.
+template <typename OpTy>
+static LogicalResult convertLoopNestToGPULaunch(OpTy forOp,
+ unsigned numBlockDims,
+ unsigned numThreadDims) {
+ if (failed(checkLoopNestMappable(forOp, numBlockDims, numThreadDims)))
+ return failure();
+
+ LoopToGpuConverter converter;
+ auto maybeInnerLoop =
+ converter.collectBounds(forOp, numBlockDims + numThreadDims);
+ if (!maybeInnerLoop)
+ return failure();
+ converter.createLaunch(forOp, *maybeInnerLoop, numBlockDims, numThreadDims);
+
+ return success();
+}
+
+// Generic loop-to-GPU kernel conversion function for imperfectly nested loops.
+// The number of workgroups and the workgroup size are provided as inputs.
+template <typename OpTy>
+static LogicalResult convertLoopToGPULaunch(OpTy forOp,
+ ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSize) {
+ if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(),
+ workGroupSize.size()))) {
+ return failure();
+ }
+ return createLaunchFromOp(forOp, numWorkGroups, workGroupSize);
+}
+
+LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp,
+ unsigned numBlockDims,
+ unsigned numThreadDims) {
+ return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims);
+}
+
+LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp,
+ unsigned numBlockDims,
+ unsigned numThreadDims) {
+ return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims);
+}
+
+LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp,
+ ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSizes) {
+ return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes);
+}
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
new file mode 100644
index 00000000000..c3bbf274818
--- /dev/null
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -0,0 +1,147 @@
+//===- LoopsToGPUPass.cpp - Convert a loop nest to a GPU kernel -----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/CommandLine.h"
+
+#define PASS_NAME "convert-loops-to-gpu"
+#define LOOPOP_TO_GPU_PASS_NAME "convert-loop-op-to-gpu"
+
+using namespace mlir;
+using namespace mlir::loop;
+
+static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options");
+static llvm::cl::opt<unsigned>
+ clNumBlockDims("gpu-block-dims",
+ llvm::cl::desc("Number of GPU block dimensions for mapping"),
+ llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+static llvm::cl::opt<unsigned> clNumThreadDims(
+ "gpu-thread-dims",
+ llvm::cl::desc("Number of GPU thread dimensions for mapping"),
+ llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u));
+
+static llvm::cl::OptionCategory clLoopOpToGPUCategory(LOOPOP_TO_GPU_PASS_NAME
+ " options");
+static llvm::cl::list<unsigned>
+ clNumWorkGroups("gpu-num-workgroups",
+ llvm::cl::desc("Num workgroups in the GPU launch"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clLoopOpToGPUCategory));
+static llvm::cl::list<unsigned>
+ clWorkGroupSize("gpu-workgroup-size",
+ llvm::cl::desc("Workgroup Size in the GPU launch"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clLoopOpToGPUCategory));
+
+namespace {
+// A pass that traverses top-level loops in the function and converts them to
+// GPU launch operations. Nested launches are not allowed, so this does not
+// walk the function recursively to avoid considering nested loops.
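+//
+// Illustrative invocation (flag values are made up):
+//   mlir-opt -convert-loops-to-gpu -gpu-block-dims=2 -gpu-thread-dims=2 \
+//            input.mlir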
+struct ForLoopMapper : public FunctionPass<ForLoopMapper> {
+ ForLoopMapper(unsigned numBlockDims, unsigned numThreadDims)
+ : numBlockDims(numBlockDims), numThreadDims(numThreadDims) {}
+
+ void runOnFunction() override {
+ for (Block &block : getFunction())
+ for (Operation &op : llvm::make_early_inc_range(block)) {
+ if (auto forOp = dyn_cast<AffineForOp>(&op)) {
+ if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
+ numThreadDims)))
+ signalPassFailure();
+ } else if (auto forOp = dyn_cast<ForOp>(&op)) {
+ if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims,
+ numThreadDims)))
+ signalPassFailure();
+ }
+ }
+ }
+
+ unsigned numBlockDims;
+ unsigned numThreadDims;
+};
+
+// A pass that traverses top-level loops in the function and converts them to
+// GPU launch operations. The top-level loop itself does not have to be
+// perfectly nested. The only requirement is that there be as many perfectly
+// nested loops as the size of `numWorkGroups`. Within these, any loop nest has
+// to be perfectly nested up to a depth equal to the size of `workGroupSize`.
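+//
+// Illustrative invocation (flag values are made up):
+//   mlir-opt -convert-loop-op-to-gpu -gpu-num-workgroups=4,2 \
+//            -gpu-workgroup-size=32,4 input.mlir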
+struct ImperfectlyNestedForLoopMapper
+ : public FunctionPass<ImperfectlyNestedForLoopMapper> {
+ ImperfectlyNestedForLoopMapper(ArrayRef<int64_t> numWorkGroups,
+ ArrayRef<int64_t> workGroupSize)
+ : numWorkGroups(numWorkGroups.begin(), numWorkGroups.end()),
+ workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+
+ void runOnFunction() override {
+ // Insert the num work groups and workgroup sizes as constant values. This
+ // pass is only used for testing.
+ FuncOp funcOp = getFunction();
+ OpBuilder builder(funcOp.getOperation()->getRegion(0));
+ SmallVector<Value, 3> numWorkGroupsVal, workGroupSizeVal;
+ for (auto val : numWorkGroups) {
+ auto constOp = builder.create<ConstantOp>(
+ funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
+ numWorkGroupsVal.push_back(constOp);
+ }
+ for (auto val : workGroupSize) {
+ auto constOp = builder.create<ConstantOp>(
+ funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
+ workGroupSizeVal.push_back(constOp);
+ }
+ for (Block &block : getFunction()) {
+ for (Operation &op : llvm::make_early_inc_range(block)) {
+ if (auto forOp = dyn_cast<ForOp>(&op)) {
+ if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
+ workGroupSizeVal))) {
+ return signalPassFailure();
+ }
+ }
+ }
+ }
+ }
+ SmallVector<int64_t, 3> numWorkGroups;
+ SmallVector<int64_t, 3> workGroupSize;
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims,
+ unsigned numThreadDims) {
+ return std::make_unique<ForLoopMapper>(numBlockDims, numThreadDims);
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups,
+ ArrayRef<int64_t> workGroupSize) {
+ return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups,
+ workGroupSize);
+}
+
+static PassRegistration<ForLoopMapper>
+ registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] {
+ return std::make_unique<ForLoopMapper>(clNumBlockDims.getValue(),
+ clNumThreadDims.getValue());
+ });
+
+static PassRegistration<ImperfectlyNestedForLoopMapper> loopOpToGPU(
+ LOOPOP_TO_GPU_PASS_NAME, "Convert top-level loop::ForOp to GPU kernels",
+ [] {
+ SmallVector<int64_t, 3> numWorkGroups, workGroupSize;
+ numWorkGroups.assign(clNumWorkGroups.begin(), clNumWorkGroups.end());
+ workGroupSize.assign(clWorkGroupSize.begin(), clWorkGroupSize.end());
+ return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups,
+ workGroupSize);
+ });
diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000..6334c273493
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_llvm_library(MLIRStandardToLLVM
+ ConvertStandardToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/StandardToLLVM
+)
+add_dependencies(
+ MLIRStandardToLLVM
+
+ MLIRLoopToStandard
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+)
+target_link_libraries(
+ MLIRStandardToLLVM
+
+ MLIRLoopToStandard
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+)
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
new file mode 100644
index 00000000000..0c96cc5e9c7
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -0,0 +1,2278 @@
+//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion ----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+
+#define PASS_NAME "convert-std-to-llvm"
+
+static llvm::cl::OptionCategory
+ clOptionsCategory("Standard to LLVM lowering options");
+
+static llvm::cl::opt<bool>
+ clUseAlloca(PASS_NAME "-use-alloca",
+ llvm::cl::desc("Replace emission of malloc/free by alloca"),
+ llvm::cl::init(false));
+
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
+ : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
+ assert(llvmDialect && "LLVM IR dialect is not registered");
+ module = &llvmDialect->getLLVMModule();
+}
+
+// Get the LLVM context.
+llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
+ return module->getContext();
+}
+
+// Extract an LLVM IR type from the LLVM IR dialect type.
+LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
+ if (!type)
+ return nullptr;
+ auto *mlirContext = type.getContext();
+ auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+ if (!wrappedLLVMType)
+ emitError(UnknownLoc::get(mlirContext),
+ "conversion resulted in a non-LLVM type");
+ return wrappedLLVMType;
+}
+
+LLVM::LLVMType LLVMTypeConverter::getIndexType() {
+ return LLVM::LLVMType::getIntNTy(
+ llvmDialect, module->getDataLayout().getPointerSizeInBits());
+}
+
+Type LLVMTypeConverter::convertIndexType(IndexType type) {
+ return getIndexType();
+}
+
+Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
+ return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
+}
+
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
+ switch (type.getKind()) {
+ case mlir::StandardTypes::F32:
+ return LLVM::LLVMType::getFloatTy(llvmDialect);
+ case mlir::StandardTypes::F64:
+ return LLVM::LLVMType::getDoubleTy(llvmDialect);
+ case mlir::StandardTypes::F16:
+ return LLVM::LLVMType::getHalfTy(llvmDialect);
+ case mlir::StandardTypes::BF16: {
+ auto *mlirContext = llvmDialect->getContext();
+ return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
+ Type();
+ }
+ default:
+ llvm_unreachable("non-float type in convertFloatType");
+ }
+}
+
+// Except for signatures, MLIR function types are converted into LLVM
+// pointer-to-function types.
+Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
+ SignatureConversion conversion(type.getNumInputs());
+ LLVM::LLVMType converted =
+ convertFunctionSignature(type, /*isVariadic=*/false, conversion);
+ return converted.getPointerTo();
+}
+
+// Function types are converted to LLVM Function types by recursively converting
+// argument and result types. If the MLIR function has zero results, the LLVM
+// function has one VoidType result. If the MLIR function has more than one
+// result, the results are packed into an LLVM StructType in their order of
+// appearance.
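+//
+// For example (illustrative sketch, not taken from a test): a function type
+// `(i32, f32) -> (i64, f64)` is converted to an LLVM function taking
+// (i32, float) and returning the packed struct `!llvm<"{ i64, double }">`.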
+LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
+ FunctionType type, bool isVariadic,
+ LLVMTypeConverter::SignatureConversion &result) {
+ // Convert argument types one by one and check for errors.
+ for (auto &en : llvm::enumerate(type.getInputs())) {
+ Type type = en.value();
+ auto converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+ if (!converted)
+ return {};
+ if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
+ converted = converted.getPointerTo();
+ result.addInputs(en.index(), converted);
+ }
+
+ SmallVector<LLVM::LLVMType, 8> argTypes;
+ argTypes.reserve(llvm::size(result.getConvertedTypes()));
+ for (Type type : result.getConvertedTypes())
+ argTypes.push_back(unwrap(type));
+
+  // If the function does not return anything, create the void result type;
+  // if it returns one element, convert it; otherwise pack the result types
+  // into a struct.
+ LLVM::LLVMType resultType =
+ type.getNumResults() == 0
+ ? LLVM::LLVMType::getVoidTy(llvmDialect)
+ : unwrap(packFunctionResults(type.getResults()));
+ if (!resultType)
+ return {};
+ return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
+}
+
+// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
+// contains:
+// 1. a pointer to the allocated (typed) data buffer, followed by
+// 2. a pointer to the aligned (typed) payload that accesses go through (equal
+//    to the allocated pointer when no alignment is requested), followed by
+// 3. a lowered `index`-type integer containing the distance between the
+//    beginning of the buffer and the first element to be accessed through the
+//    view, followed by
+// 4. an array containing as many `index`-type integers as the rank of the
+//    MemRef: the array represents the size, in number of elements, of the
+//    memref along the given dimension. For constant MemRef dimensions, the
+//    corresponding size entry is a constant whose runtime value must match
+//    the static value, followed by
+// 5. a second array containing as many `index`-type integers as the rank of
+//    the MemRef: the second array represents the "stride" (in tensor
+//    abstraction sense), i.e. the number of elements of the underlying buffer
+//    one needs to skip to advance by one along the given dimension.
+// TODO(ntv, zinenko): add assertions for the static cases.
+//
+// template <typename Elem, size_t Rank>
+// struct {
+// Elem *allocatedPtr;
+// Elem *alignedPtr;
+// int64_t offset;
+// int64_t sizes[Rank]; // omitted when rank == 0
+// int64_t strides[Rank]; // omitted when rank == 0
+// };
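+//
+// For example (illustrative only): `memref<?x42xf32>` lowers to the LLVM
+// struct `{ float*, float*, i64, [2 x i64], [2 x i64] }`, while a 0-D
+// `memref<f32>` lowers to `{ float*, float*, i64 }` (the two arrays are
+// omitted).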
+static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
+static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
+static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
+static constexpr unsigned kSizePosInMemRefDescriptor = 3;
+static constexpr unsigned kStridePosInMemRefDescriptor = 4;
+Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
+ assert(strideSuccess &&
+ "Non-strided layout maps must have been normalized away");
+ (void)strideSuccess;
+ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+ if (!elementType)
+ return {};
+ auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
+ auto indexTy = getIndexType();
+ auto rank = type.getRank();
+ if (rank > 0) {
+ auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
+ return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy);
+ }
+ return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
+}
+
+// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
+// contains:
+// 1. int64_t rank, the dynamic rank of this MemRef;
+// 2. void* ptr, a pointer to the static ranked MemRef descriptor. This is a
+//    stack-allocated (alloca) copy of the ranked MemRef descriptor that was
+//    cast to be unranked.
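+// In LLVM terms the resulting descriptor is therefore the struct
+// `{ i64, i8* }`, independent of the element type and rank of the original
+// memref.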
+
+static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
+static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
+
+Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
+ auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
+}
+
+// Convert an n-D vector type to an LLVM vector type via an (n-1)-D array type
+// when n > 1.
+// For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and
+// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
+Type LLVMTypeConverter::convertVectorType(VectorType type) {
+ auto elementType = unwrap(convertType(type.getElementType()));
+ if (!elementType)
+ return {};
+ auto vectorType =
+ LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
+ auto shape = type.getShape();
+ for (int i = shape.size() - 2; i >= 0; --i)
+ vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
+ return vectorType;
+}
+
+// Dispatch based on the actual type. Return null type on error.
+Type LLVMTypeConverter::convertStandardType(Type t) {
+ return TypeSwitch<Type, Type>(t)
+ .Case([&](FloatType type) { return convertFloatType(type); })
+ .Case([&](FunctionType type) { return convertFunctionType(type); })
+ .Case([&](IndexType type) { return convertIndexType(type); })
+ .Case([&](IntegerType type) { return convertIntegerType(type); })
+ .Case([&](MemRefType type) { return convertMemRefType(type); })
+ .Case([&](UnrankedMemRefType type) {
+ return convertUnrankedMemRefType(type);
+ })
+ .Case([&](VectorType type) { return convertVectorType(type); })
+ .Case([](LLVM::LLVMType type) { return type; })
+ .Default([](Type) { return Type(); });
+}
+
+LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
+ LLVMTypeConverter &lowering_,
+ PatternBenefit benefit)
+ : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
+
+/*============================================================================*/
+/* StructBuilder implementation */
+/*============================================================================*/
+StructBuilder::StructBuilder(Value v) : value(v) {
+ assert(value != nullptr && "value cannot be null");
+ structType = value->getType().cast<LLVM::LLVMType>();
+}
+
+Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
+ unsigned pos) {
+ Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+ return builder.create<LLVM::ExtractValueOp>(loc, type, value,
+ builder.getI64ArrayAttr(pos));
+}
+
+void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
+ Value ptr) {
+ value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
+ builder.getI64ArrayAttr(pos));
+}
+/*============================================================================*/
+/* MemRefDescriptor implementation */
+/*============================================================================*/
+
+/// Construct a helper for the given descriptor value.
+MemRefDescriptor::MemRefDescriptor(Value descriptor)
+ : StructBuilder(descriptor) {
+ assert(value != nullptr && "value cannot be null");
+ indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ kOffsetPosInMemRefDescriptor);
+}
+
+/// Builds IR creating an `undef` value of the descriptor type.
+MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
+ Type descriptorType) {
+
+ Value descriptor =
+ builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+ return MemRefDescriptor(descriptor);
+}
+
+/// Builds IR creating a MemRef descriptor that represents `type` and
+/// populates it with static shape and stride information extracted from the
+/// type.
+MemRefDescriptor
+MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ MemRefType type, Value memory) {
+ assert(type.hasStaticShape() && "unexpected dynamic shape");
+ assert(type.getAffineMaps().empty() && "unexpected layout map");
+
+ auto convertedType = typeConverter.convertType(type);
+ assert(convertedType && "unexpected failure in memref type conversion");
+
+ auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
+ descr.setAllocatedPtr(builder, loc, memory);
+ descr.setAlignedPtr(builder, loc, memory);
+ descr.setConstantOffset(builder, loc, 0);
+
+ // Fill in sizes and strides, in reverse order to simplify stride
+ // calculation.
+ uint64_t runningStride = 1;
+ for (unsigned i = type.getRank(); i > 0; --i) {
+ unsigned dim = i - 1;
+ descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
+ descr.setConstantStride(builder, loc, dim, runningStride);
+ runningStride *= type.getDimSize(dim);
+ }
+ return descr;
+}
+
+/// Builds IR extracting the allocated pointer from the descriptor.
+Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
+}
+
+/// Builds IR inserting the allocated pointer into the descriptor.
+void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
+ Value ptr) {
+ setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
+}
+
+/// Builds IR extracting the aligned pointer from the descriptor.
+Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+}
+
+/// Builds IR inserting the aligned pointer into the descriptor.
+void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
+ Value ptr) {
+ setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
+}
+
+// Creates a constant Op producing a value of `resultType` from an index-typed
+// integer attribute.
+static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
+ Type resultType, int64_t value) {
+ return builder.create<LLVM::ConstantOp>(
+ loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
+}
+
+/// Builds IR extracting the offset from the descriptor.
+Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
+}
+
+/// Builds IR inserting the offset into the descriptor.
+void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
+ Value offset) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, offset,
+ builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
+}
+
+/// Builds IR inserting the offset into the descriptor.
+void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
+ uint64_t offset) {
+ setOffset(builder, loc,
+ createIndexAttrConstant(builder, loc, indexType, offset));
+}
+
+/// Builds IR extracting the pos-th size from the descriptor.
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
+}
+
+/// Builds IR inserting the pos-th size into the descriptor
+void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
+ Value size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, size,
+ builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
+}
+
+/// Builds IR inserting the pos-th size into the descriptor
+void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
+ unsigned pos, uint64_t size) {
+ setSize(builder, loc, pos,
+ createIndexAttrConstant(builder, loc, indexType, size));
+}
+
+/// Builds IR extracting the pos-th stride from the descriptor.
+Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
+}
+
+/// Builds IR inserting the pos-th stride into the descriptor
+void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
+ Value stride) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, stride,
+ builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
+}
+
+/// Builds IR inserting the pos-th stride into the descriptor
+void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
+ unsigned pos, uint64_t stride) {
+ setStride(builder, loc, pos,
+ createIndexAttrConstant(builder, loc, indexType, stride));
+}
+
+LLVM::LLVMType MemRefDescriptor::getElementType() {
+ return value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ kAlignedPtrPosInMemRefDescriptor);
+}
+
+/*============================================================================*/
+/* UnrankedMemRefDescriptor implementation */
+/*============================================================================*/
+
+/// Construct a helper for the given descriptor value.
+UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
+ : StructBuilder(descriptor) {}
+
+/// Builds IR creating an `undef` value of the descriptor type.
+UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
+ Location loc,
+ Type descriptorType) {
+ Value descriptor =
+ builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+ return UnrankedMemRefDescriptor(descriptor);
+}
+Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
+}
+void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
+ Value v) {
+ setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
+}
+Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
+ Location loc) {
+ return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
+}
+void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
+ Location loc, Value v) {
+ setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
+}
+namespace {
+// Base class for Standard to LLVM IR op conversions. Matches the Op type
+// provided as template argument. Carries a reference to the LLVM dialect in
+// case it is necessary for rewriters.
+template <typename SourceOp>
+class LLVMLegalizationPattern : public LLVMOpLowering {
+public:
+ // Construct a conversion pattern.
+ explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
+ LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
+ lowering_),
+ dialect(dialect_) {}
+
+ // Get the LLVM IR dialect.
+ LLVM::LLVMDialect &getDialect() const { return dialect; }
+ // Get the LLVM context.
+ llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
+ // Get the LLVM module in which the types are constructed.
+ llvm::Module &getModule() const { return dialect.getLLVMModule(); }
+
+ // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
+ // by the pointer size used in the LLVM module.
+ LLVM::LLVMType getIndexType() const {
+ return LLVM::LLVMType::getIntNTy(
+ &dialect, getModule().getDataLayout().getPointerSizeInBits());
+ }
+
+ LLVM::LLVMType getVoidType() const {
+ return LLVM::LLVMType::getVoidTy(&dialect);
+ }
+
+ // Get the MLIR type wrapping the LLVM i8* type.
+ LLVM::LLVMType getVoidPtrType() const {
+ return LLVM::LLVMType::getInt8PtrTy(&dialect);
+ }
+
+ // Create an LLVM IR pseudo-operation defining the given index constant.
+ Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
+ uint64_t value) const {
+ return createIndexAttrConstant(builder, loc, getIndexType(), value);
+ }
+
+protected:
+ LLVM::LLVMDialect &dialect;
+};
+
+struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
+ using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto funcOp = cast<FuncOp>(op);
+ FunctionType type = funcOp.getType();
+
+ // Store the positions of memref-typed arguments so that we can emit loads
+ // from them to follow the calling convention.
+ SmallVector<unsigned, 4> promotedArgIndices;
+ promotedArgIndices.reserve(type.getNumInputs());
+ for (auto en : llvm::enumerate(type.getInputs())) {
+ if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
+ promotedArgIndices.push_back(en.index());
+ }
+
+ // Convert the original function arguments. Struct arguments are promoted to
+ // pointer to struct arguments to allow calling external functions with
+ // various ABIs (e.g. compiled from C/C++ on platform X).
+ auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
+ TypeConverter::SignatureConversion result(funcOp.getNumArguments());
+ auto llvmType = lowering.convertFunctionSignature(
+ funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
+
+ // Only retain those attributes that are not constructed by build.
+ SmallVector<NamedAttribute, 4> attributes;
+ for (const auto &attr : funcOp.getAttrs()) {
+ if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
+ attr.first.is(impl::getTypeAttrName()) ||
+ attr.first.is("std.varargs"))
+ continue;
+ attributes.push_back(attr);
+ }
+
+ // Create an LLVM function, use external linkage by default until MLIR
+ // functions have linkage.
+ auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+ op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
+ attributes);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+
+ // Tell the rewriter to convert the region signature.
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+
+ // Insert loads from memref descriptor pointers in function bodies.
+ if (!newFuncOp.getBody().empty()) {
+ Block *firstBlock = &newFuncOp.getBody().front();
+ rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
+ for (unsigned idx : promotedArgIndices) {
+ BlockArgument arg = firstBlock->getArgument(idx);
+ Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
+ rewriter.replaceUsesOfBlockArgument(arg, loaded);
+ }
+ }
+
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+
+//////////////// Support for Lowering operations on n-D vectors ////////////////
+namespace {
+// Helper struct to "unroll" operations on n-D vectors in terms of operations on
+// 1-D LLVM vectors.
+struct NDVectorTypeInfo {
+ // LLVM array struct which encodes n-D vectors.
+ LLVM::LLVMType llvmArrayTy;
+ // LLVM vector type which encodes the inner 1-D vector type.
+ LLVM::LLVMType llvmVectorTy;
+ // Multiplicity of llvmArrayTy to llvmVectorTy.
+ SmallVector<int64_t, 4> arraySizes;
+};
+} // namespace
+
+// For >1-D vector types, extracts the necessary information to iterate over
+// all the 1-D subvectors in the underlying LLVM representation of the n-D
+// vector. Iterates over the LLVM array type until we hit a non-array type
+// (which is expected to be an LLVM vector type).
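+//
+// For example (illustrative): for `vector<4 x 8 x 16 x f32>` this computes
+// llvmArrayTy = `[4 x [8 x <16 x float>]]`, arraySizes = {4, 8} and
+// llvmVectorTy = `<16 x float>`.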
+static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
+ LLVMTypeConverter &converter) {
+ assert(vectorType.getRank() > 1 && "expected >1D vector type");
+ NDVectorTypeInfo info;
+ info.llvmArrayTy =
+ converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
+ if (!info.llvmArrayTy)
+ return info;
+ info.arraySizes.reserve(vectorType.getRank() - 1);
+ auto llvmTy = info.llvmArrayTy;
+ while (llvmTy.isArrayTy()) {
+ info.arraySizes.push_back(llvmTy.getArrayNumElements());
+ llvmTy = llvmTy.getArrayElementType();
+ }
+ if (!llvmTy.isVectorTy())
+ return info;
+ info.llvmVectorTy = llvmTy;
+ return info;
+}
+
+// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when `linearIndex` is out of the range [0, P) where
+// P is the product of all the basis coordinates.
+//
+// Prerequisites:
+// Basis is an array of nonnegative integers (signed type inherited from
+// vector shape type).
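+//
+// Worked example (illustrative): with basis = {4, 8, 16} and
+// linearIndex = 309 = ((2 * 8) + 3) * 16 + 5, the result is {2, 3, 5};
+// any linearIndex >= 4 * 8 * 16 = 512 yields the empty vector.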
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+ unsigned linearIndex) {
+ SmallVector<int64_t, 4> res;
+ res.reserve(basis.size());
+ for (unsigned basisElement : llvm::reverse(basis)) {
+ res.push_back(linearIndex % basisElement);
+ linearIndex = linearIndex / basisElement;
+ }
+ if (linearIndex > 0)
+ return {};
+ std::reverse(res.begin(), res.end());
+ return res;
+}
+
+// Iterate over the linear index space of the n-D vector, convert each linear
+// index to coordinates, and invoke `fun` with the resulting position (e.g. to
+// insert a splatted 1-D vector at each position).
+template <typename Lambda>
+void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
+ Lambda fun) {
+ unsigned ub = 1;
+ for (auto s : info.arraySizes)
+ ub *= s;
+ for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+ auto coords = getCoordinates(info.arraySizes, linearIndex);
+ // Linear index is out of bounds, we are done.
+ if (coords.empty())
+ break;
+ assert(coords.size() == info.arraySizes.size());
+ auto position = builder.getI64ArrayAttr(coords);
+ fun(position);
+ }
+}
+////////////// End Support for Lowering operations on n-D vectors //////////////
+
+// Basic lowering implementation for one-to-one rewriting from Standard Ops to
+// LLVM Dialect Ops.
+template <typename SourceOp, typename TargetOp>
+struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+ using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+ using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
+
+ // Convert the type of the result to an LLVM type, pass operands as is,
+ // preserve attributes.
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ unsigned numResults = op->getNumResults();
+
+ Type packedType;
+ if (numResults != 0) {
+ packedType = this->lowering.packFunctionResults(
+ llvm::to_vector<4>(op->getResultTypes()));
+ if (!packedType)
+ return this->matchFailure();
+ }
+
+ auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
+ op->getAttrs());
+
+ // If the operation produced 0 or 1 result, return them immediately.
+ if (numResults == 0)
+ return rewriter.eraseOp(op), this->matchSuccess();
+ if (numResults == 1)
+ return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
+ this->matchSuccess();
+
+ // Otherwise, it had been converted to an operation producing a structure.
+ // Extract individual results from the structure and return them as list.
+ SmallVector<Value, 4> results;
+ results.reserve(numResults);
+ for (unsigned i = 0; i < numResults; ++i) {
+ auto type = this->lowering.convertType(op->getResult(i)->getType());
+ results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ op->getLoc(), type, newOp.getOperation()->getResult(0),
+ rewriter.getI64ArrayAttr(i)));
+ }
+ rewriter.replaceOp(op, results);
+ return this->matchSuccess();
+ }
+};
+
+template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
+ static_assert(
+ std::is_base_of<
+ typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
+ SourceOp>::value,
+ "wrong operand count");
+};
+
+template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
+ static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
+ "expected a single operand");
+};
+
+template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
+ OpCountValidator<SourceOp, OpCount>();
+}
+
+// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
+// Ops for N-ary ops with one result. This supports higher-dimensional vector
+// types.
+template <typename SourceOp, typename TargetOp, unsigned OpCount>
+struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+ using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+ using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
+
+ // Convert the type of the result to an LLVM type, pass operands as is,
+ // preserve attributes.
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ ValidateOpCount<SourceOp, OpCount>();
+ static_assert(
+ std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+ "expected single result op");
+ static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+ SourceOp>::value,
+ "expected same operands and result type");
+
+ // Cannot convert ops if their operands are not of LLVM type.
+ for (Value operand : operands) {
+ if (!operand || !operand->getType().isa<LLVM::LLVMType>())
+ return this->matchFailure();
+ }
+
+ auto loc = op->getLoc();
+ auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
+
+ if (!llvmArrayTy.isArrayTy()) {
+ auto newOp = rewriter.create<TargetOp>(
+ op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
+ rewriter.replaceOp(op, newOp.getResult());
+ return this->matchSuccess();
+ }
+
+ auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>();
+ if (!vectorType)
+ return this->matchFailure();
+ auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
+ auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+ if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
+ return this->matchFailure();
+
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ // For this unrolled `position` corresponding to the `linearIndex`^th
+ // element, extract operand vectors
+ SmallVector<Value, OpCount> extractedOperands;
+ for (unsigned i = 0; i < OpCount; ++i) {
+ extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operands[i], position));
+ }
+ Value newVal = rewriter.create<TargetOp>(
+ loc, llvmVectorTy, extractedOperands, op->getAttrs());
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
+ newVal, position);
+ });
+ rewriter.replaceOp(op, desc);
+ return this->matchSuccess();
+ }
+};
+
+template <typename SourceOp, typename TargetOp>
+using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
+template <typename SourceOp, typename TargetOp>
+using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
+
+// Specific lowerings.
+// FIXME: this should be tablegen'ed.
+struct AbsFOpLowering : public UnaryOpLLVMOpLowering<AbsFOp, LLVM::FAbsOp> {
+ using Super::Super;
+};
+struct CeilFOpLowering : public UnaryOpLLVMOpLowering<CeilFOp, LLVM::FCeilOp> {
+ using Super::Super;
+};
+struct CosOpLowering : public UnaryOpLLVMOpLowering<CosOp, LLVM::CosOp> {
+ using Super::Super;
+};
+struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
+ using Super::Super;
+};
+struct LogOpLowering : public UnaryOpLLVMOpLowering<LogOp, LLVM::LogOp> {
+ using Super::Super;
+};
+struct Log10OpLowering : public UnaryOpLLVMOpLowering<Log10Op, LLVM::Log10Op> {
+ using Super::Super;
+};
+struct Log2OpLowering : public UnaryOpLLVMOpLowering<Log2Op, LLVM::Log2Op> {
+ using Super::Super;
+};
+struct NegFOpLowering : public UnaryOpLLVMOpLowering<NegFOp, LLVM::FNegOp> {
+ using Super::Super;
+};
+struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
+ using Super::Super;
+};
+struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
+ using Super::Super;
+};
+struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
+ using Super::Super;
+};
+struct SignedDivIOpLowering
+ : public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
+ using Super::Super;
+};
+struct UnsignedDivIOpLowering
+ : public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
+ using Super::Super;
+};
+struct SignedRemIOpLowering
+ : public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> {
+ using Super::Super;
+};
+struct UnsignedRemIOpLowering
+ : public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> {
+ using Super::Super;
+};
+struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
+ using Super::Super;
+};
+struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
+ using Super::Super;
+};
+struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
+ using Super::Super;
+};
+struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
+ using Super::Super;
+};
+struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
+ using Super::Super;
+};
+struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
+ using Super::Super;
+};
+struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
+ using Super::Super;
+};
+struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
+ using Super::Super;
+};
+struct CopySignOpLowering
+ : public BinaryOpLLVMOpLowering<CopySignOp, LLVM::CopySignOp> {
+ using Super::Super;
+};
+struct SelectOpLowering
+ : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
+ using Super::Super;
+};
+struct ConstLLVMOpLowering
+ : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
+ using Super::Super;
+};
+struct ShiftLeftOpLowering
+ : public OneToOneLLVMOpLowering<ShiftLeftOp, LLVM::ShlOp> {
+ using Super::Super;
+};
+struct SignedShiftRightOpLowering
+ : public OneToOneLLVMOpLowering<SignedShiftRightOp, LLVM::AShrOp> {
+ using Super::Super;
+};
+struct UnsignedShiftRightOpLowering
+ : public OneToOneLLVMOpLowering<UnsignedShiftRightOp, LLVM::LShrOp> {
+ using Super::Super;
+};
+
+// Check if the MemRefType `type` is supported by the lowering. We currently
+// only support memrefs with identity maps.
+static bool isSupportedMemRefType(MemRefType type) {
+ return type.getAffineMaps().empty() ||
+ llvm::all_of(type.getAffineMaps(),
+ [](AffineMap map) { return map.isIdentity(); });
+}
+
+// An `alloc` is converted into a definition of a memref descriptor value and
+// a call to `malloc` to allocate the underlying data buffer. The memref
+// descriptor is of the LLVM structure type where:
+// 1. the first element is a pointer to the allocated (typed) data buffer,
+// 2. the second element is a pointer to the (typed) payload, aligned to the
+// specified alignment,
+// 3. the remaining elements store the offset, followed by the sizes and
+//    strides of the memref, using the LLVM-converted `index` type.
+//
+// Alignment is obtained by allocating `alignment - 1` more bytes than requested
+// and shifting the aligned pointer relative to the allocated memory. If
+// alignment is unspecified, the two pointers are equal.
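+//
+// For example (illustrative): with a requested alignment of 16 and a malloc'ed
+// address congruent to 8 modulo 16, the shift computed below is
+// (16 - (ptr % 16)) % 16 = 8, so the aligned pointer is 8 bytes past the
+// allocated one; the `alignment - 1` extra bytes guarantee the shifted pointer
+// still lies within the allocation.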
+struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
+ using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
+
+ AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
+ bool useAlloca = false)
+ : LLVMLegalizationPattern<AllocOp>(dialect_, converter),
+ useAlloca(useAlloca) {}
+
+ PatternMatchResult match(Operation *op) const override {
+ MemRefType type = cast<AllocOp>(op).getType();
+ if (isSupportedMemRefType(type))
+ return matchSuccess();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(type, strides, offset);
+ if (failed(successStrides))
+ return matchFailure();
+
+    // Dynamic strides are ok if they can be deduced from dynamic sizes (which
+    // is guaranteed when succeeded(successStrides)). A dynamic offset,
+    // however, can never be alloc'ed.
+ if (offset == MemRefType::getDynamicStrideOrOffset())
+ return matchFailure();
+
+ return matchSuccess();
+ }
+
+ void rewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto allocOp = cast<AllocOp>(op);
+ MemRefType type = allocOp.getType();
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to 'alloc' as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
+ SmallVector<Value, 4> sizes;
+ sizes.reserve(type.getRank());
+ unsigned i = 0;
+ for (int64_t s : type.getShape())
+ sizes.push_back(s == -1 ? operands[i++]
+ : createIndexConstant(rewriter, loc, s));
+ if (sizes.empty())
+ sizes.push_back(createIndexConstant(rewriter, loc, 1));
+
+ // Compute the total number of memref elements.
+ Value cumulativeSize = sizes.front();
+ for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+ cumulativeSize = rewriter.create<LLVM::MulOp>(
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]});
+
+ // Compute the size of an individual element. This emits the MLIR equivalent
+ // of the following sizeof(...) implementation in LLVM IR:
+ // %0 = getelementptr %elementType* null, %indexType 1
+ // %1 = ptrtoint %elementType* %0 to %indexType
+ // which is a common pattern of getting the size of a type in bytes.
+ auto elementType = type.getElementType();
+ auto convertedPtrType =
+ lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
+ auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
+ auto one = createIndexConstant(rewriter, loc, 1);
+ auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
+ ArrayRef<Value>{nullPtr, one});
+ auto elementSize =
+ rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+ cumulativeSize = rewriter.create<LLVM::MulOp>(
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize});
+
+ // Allocate the underlying buffer and store a pointer to it in the MemRef
+ // descriptor.
+ Value allocated = nullptr;
+ int alignment = 0;
+ Value alignmentValue = nullptr;
+ if (auto alignAttr = allocOp.alignment())
+ alignment = alignAttr.getValue().getSExtValue();
+
+ if (useAlloca) {
+ allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
+ cumulativeSize, alignment);
+ } else {
+ // Insert the `malloc` declaration if it is not already present.
+ auto module = op->getParentOfType<ModuleOp>();
+ auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
+ if (!mallocFunc) {
+ OpBuilder moduleBuilder(
+ op->getParentOfType<ModuleOp>().getBodyRegion());
+ mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
+ rewriter.getUnknownLoc(), "malloc",
+ LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
+ /*isVarArg=*/false));
+ }
+ if (alignment != 0) {
+ alignmentValue = createIndexConstant(rewriter, loc, alignment);
+ cumulativeSize = rewriter.create<LLVM::SubOp>(
+ loc,
+ rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
+ one);
+ }
+ allocated = rewriter
+ .create<LLVM::CallOp>(
+ loc, getVoidPtrType(),
+ rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
+ .getResult(0);
+ }
+
+ auto structElementType = lowering.convertType(elementType);
+ auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
+ type.getMemorySpace());
+ Value bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
+ loc, elementPtrType, ArrayRef<Value>(allocated));
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(type, strides, offset);
+ assert(succeeded(successStrides) && "unexpected non-strided memref");
+ (void)successStrides;
+ assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+ "unexpected dynamic offset");
+
+    // 0-D memref corner case: there is a single size entry (1) and no strides.
+ assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
+ (strides.size() == sizes.size())) &&
+ "unexpected number of strides");
+
+ // Create the MemRef descriptor.
+ auto structType = lowering.convertType(type);
+ auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
+ // Field 1: Allocated pointer, used for malloc/free.
+ memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
+
+ // Field 2: Actual aligned pointer to payload.
+ Value bitcastAligned = bitcastAllocated;
+ if (!useAlloca && alignment != 0) {
+ assert(alignmentValue);
+ // offset = (align - (ptr % align))% align
+ Value intVal = rewriter.create<LLVM::PtrToIntOp>(
+ loc, this->getIndexType(), allocated);
+ Value ptrModAlign =
+ rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
+ Value subbed =
+ rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
+ Value offset = rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
+ Value aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
+ allocated, offset);
+ bitcastAligned = rewriter.create<LLVM::BitcastOp>(
+ loc, elementPtrType, ArrayRef<Value>(aligned));
+ }
+ memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
+
+ // Field 3: Offset in aligned pointer.
+ memRefDescriptor.setOffset(rewriter, loc,
+ createIndexConstant(rewriter, loc, offset));
+
+ if (type.getRank() == 0)
+ // No size/stride descriptor in memref, return the descriptor value.
+ return rewriter.replaceOp(op, {memRefDescriptor});
+
+ // Fields 4 and 5: Sizes and strides of the strided MemRef.
+ // Store all sizes in the descriptor. Only dynamic sizes are passed in as
+ // operands to AllocOp.
+ Value runningStride = nullptr;
+ // Iterate strides in reverse order, compute runningStride and strideValues.
+ auto nStrides = strides.size();
+ SmallVector<Value, 4> strideValues(nStrides, nullptr);
+ for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
+ int64_t index = nStrides - 1 - indexedStride.index();
+ if (strides[index] == MemRefType::getDynamicStrideOrOffset())
+ // Identity layout map is enforced in the match function, so we compute:
+ // `runningStride *= sizes[index]`
+ runningStride =
+ runningStride
+ ? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index])
+ : createIndexConstant(rewriter, loc, 1);
+ else
+ runningStride = createIndexConstant(rewriter, loc, strides[index]);
+ strideValues[index] = runningStride;
+ }
+ // Fill size and stride descriptors in memref.
+ for (auto indexedSize : llvm::enumerate(sizes)) {
+ int64_t index = indexedSize.index();
+ memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
+ memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
+ }
+
+ // Return the final value of the descriptor.
+ rewriter.replaceOp(op, {memRefDescriptor});
+ }
+
+ bool useAlloca;
+};
+
+// A CallOp automatically promotes MemRefType operands to a sequence of
+// alloca/store and passes a pointer to the MemRef descriptor across function
+// boundaries.
+template <typename CallOpType>
+struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
+ using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern;
+ using Super = CallOpInterfaceLowering<CallOpType>;
+ using Base = LLVMLegalizationPattern<CallOpType>;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ OperandAdaptor<CallOpType> transformed(operands);
+ auto callOp = cast<CallOpType>(op);
+
+ // Pack the result types into a struct.
+ Type packedResult;
+ unsigned numResults = callOp.getNumResults();
+ auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
+
+ for (Type resType : resultTypes) {
+      assert(!resType.isa<UnrankedMemRefType>() &&
+             "Returning unranked memref is not supported. Pass result as an "
+             "argument instead.");
+ (void)resType;
+ }
+
+ if (numResults != 0) {
+ if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
+ return this->matchFailure();
+ }
+
+ auto promoted = this->lowering.promoteMemRefDescriptors(
+ op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter);
+ auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult,
+ promoted, op->getAttrs());
+
+ // If < 2 results, packing did not do anything and we can just return.
+ if (numResults < 2) {
+ rewriter.replaceOp(op, newOp.getResults());
+ return this->matchSuccess();
+ }
+
+ // Otherwise, it had been converted to an operation producing a structure.
+ // Extract individual results from the structure and return them as list.
+ // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around
+ // a particular interaction between MemRefType and CallOp lowering. Find a
+ // way to avoid special casing.
+ SmallVector<Value, 4> results;
+ results.reserve(numResults);
+ for (unsigned i = 0; i < numResults; ++i) {
+ auto type = this->lowering.convertType(op->getResult(i)->getType());
+ results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ op->getLoc(), type, newOp.getOperation()->getResult(0),
+ rewriter.getI64ArrayAttr(i)));
+ }
+ rewriter.replaceOp(op, results);
+
+ return this->matchSuccess();
+ }
+};
+
+struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
+ using Super::Super;
+};
+
+struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
+ using Super::Super;
+};
+
+// A `dealloc` is converted into a call to `free` on the underlying data buffer.
+// The memref descriptor being an SSA value, there is no need to clean it up
+// in any way.
+struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
+ using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
+
+ DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
+ bool useAlloca = false)
+ : LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
+ useAlloca(useAlloca) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (useAlloca)
+ return rewriter.eraseOp(op), matchSuccess();
+
+ assert(operands.size() == 1 && "dealloc takes one operand");
+ OperandAdaptor<DeallocOp> transformed(operands);
+
+ // Insert the `free` declaration if it is not already present.
+ auto freeFunc =
+ op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
+ if (!freeFunc) {
+ OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
+ freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
+ rewriter.getUnknownLoc(), "free",
+ LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
+ /*isVarArg=*/false));
+ }
+
+ MemRefDescriptor memref(transformed.memref());
+ Value casted = rewriter.create<LLVM::BitcastOp>(
+ op->getLoc(), getVoidPtrType(),
+ memref.allocatedPtr(rewriter, op->getLoc()));
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
+ return matchSuccess();
+ }
+
+ bool useAlloca;
+};
+
+// A `tanh` is converted into a call to the `tanhf` or `tanh` function,
+// depending on whether the operand type is f32 or f64.
+struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
+ using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ using LLVMFuncOpT = LLVM::LLVMFuncOp;
+ using LLVMTypeT = LLVM::LLVMType;
+
+ OperandAdaptor<TanhOp> transformed(operands);
+ LLVMTypeT operandType =
+ transformed.operand()->getType().dyn_cast_or_null<LLVM::LLVMType>();
+
+ if (!operandType)
+ return matchFailure();
+
+ std::string functionName;
+ if (operandType.isFloatTy())
+ functionName = "tanhf";
+ else if (operandType.isDoubleTy())
+ functionName = "tanh";
+ else
+ return matchFailure();
+
+ // Get a reference to the tanh function, inserting it if necessary.
+ Operation *tanhFunc =
+ SymbolTable::lookupNearestSymbolFrom(op, functionName);
+
+ LLVMFuncOpT tanhLLVMFunc;
+ if (tanhFunc) {
+ tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc);
+ } else {
+ PatternRewriter::InsertionGuard insertGuard(rewriter);
+ auto module = op->getParentOfType<ModuleOp>();
+ rewriter.setInsertionPointToStart(module.getBody());
+ tanhLLVMFunc = rewriter.create<LLVMFuncOpT>(
+ module.getLoc(), functionName,
+ LLVMTypeT::getFunctionTy(operandType, operandType,
+ /*isVarArg=*/false));
+ }
+
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
+ transformed.operand());
+ return matchSuccess();
+ }
+};
+
+struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
+ using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult match(Operation *op) const override {
+ auto memRefCastOp = cast<MemRefCastOp>(op);
+ Type srcType = memRefCastOp.getOperand()->getType();
+ Type dstType = memRefCastOp.getType();
+
+ if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
+ MemRefType sourceType =
+ memRefCastOp.getOperand()->getType().cast<MemRefType>();
+ MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
+ return (isSupportedMemRefType(targetType) &&
+ isSupportedMemRefType(sourceType))
+ ? matchSuccess()
+ : matchFailure();
+ }
+
+    // At least one of the operands has an unranked memref type.
+ assert(srcType.isa<UnrankedMemRefType>() ||
+ dstType.isa<UnrankedMemRefType>());
+
+    // An unranked-to-unranked cast is disallowed.
+ return !(srcType.isa<UnrankedMemRefType>() &&
+ dstType.isa<UnrankedMemRefType>())
+ ? matchSuccess()
+ : matchFailure();
+ }
+
+ void rewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memRefCastOp = cast<MemRefCastOp>(op);
+ OperandAdaptor<MemRefCastOp> transformed(operands);
+
+ auto srcType = memRefCastOp.getOperand()->getType();
+ auto dstType = memRefCastOp.getType();
+ auto targetStructType = lowering.convertType(memRefCastOp.getType());
+ auto loc = op->getLoc();
+
+ if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
+      // memref_cast is defined for source and destination memref types with
+      // the same element type, same mappings, same address space and same
+      // rank. Therefore a simple bitcast suffices. If these conditions do not
+      // hold, the behavior is undefined.
+ rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
+ transformed.source());
+ } else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
+ // Casting ranked to unranked memref type
+ // Set the rank in the destination from the memref type
+ // Allocate space on the stack and copy the src memref descriptor
+ // Set the ptr in the destination to the stack space
+ auto srcMemRefType = srcType.cast<MemRefType>();
+ int64_t rank = srcMemRefType.getRank();
+ // ptr = AllocaOp sizeof(MemRefDescriptor)
+ auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(),
+ rewriter);
+ // voidptr = BitCastOp srcType* to void*
+ auto voidPtr =
+ rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
+ .getResult();
+ // rank = ConstantOp srcRank
+ auto rankVal = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(rewriter.getIntegerType(64)),
+ rewriter.getI64IntegerAttr(rank));
+ // undef = UndefOp
+ UnrankedMemRefDescriptor memRefDesc =
+ UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
+ // d1 = InsertValueOp undef, rank, 0
+ memRefDesc.setRank(rewriter, loc, rankVal);
+ // d2 = InsertValueOp d1, voidptr, 1
+ memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
+ rewriter.replaceOp(op, (Value)memRefDesc);
+
+ } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
+      // Casting from an unranked type to a ranked one.
+      // The operation is assumed to be doing a correct cast. If the
+      // destination type does not match the runtime type of the unranked
+      // operand, the behavior is undefined.
+ UnrankedMemRefDescriptor memRefDesc(transformed.source());
+ // ptr = ExtractValueOp src, 1
+ auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
+ // castPtr = BitCastOp i8* to structTy*
+ auto castPtr =
+ rewriter
+ .create<LLVM::BitcastOp>(
+ loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
+ ptr)
+ .getResult();
+ // struct = LoadOp castPtr
+ auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
+ rewriter.replaceOp(op, loadOp.getResult());
+ } else {
+      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
+ }
+ }
+};
+
+// A `dim` is converted to a constant for static sizes and to an access to the
+// size stored in the memref descriptor for dynamic sizes.
+struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
+ using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dimOp = cast<DimOp>(op);
+ OperandAdaptor<DimOp> transformed(operands);
+ MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
+
+ auto shape = type.getShape();
+ int64_t index = dimOp.getIndex();
+ // Extract dynamic size from the memref descriptor.
+ if (ShapedType::isDynamic(shape[index]))
+ rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
+ .size(rewriter, op->getLoc(), index)});
+ else
+ // Use constant for static size.
+ rewriter.replaceOp(
+ op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
+ return matchSuccess();
+ }
+};
+
+// Common base for load and store operations on MemRefs. Restricts the match
+// to supported MemRef types. Provides functionality to emit code accessing a
+// specific element of the underlying data buffer.
+template <typename Derived>
+struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
+ using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
+ using Base = LoadStoreOpLowering<Derived>;
+
+ PatternMatchResult match(Operation *op) const override {
+ MemRefType type = cast<Derived>(op).getMemRefType();
+ return isSupportedMemRefType(type) ? this->matchSuccess()
+ : this->matchFailure();
+ }
+
+ // Given subscript indices and array sizes in row-major order,
+ // i_n, i_{n-1}, ..., i_1
+ // s_n, s_{n-1}, ..., s_1
+ // obtain a value that corresponds to the linearized subscript
+ // \sum_k i_k * \prod_{j=1}^{k-1} s_j
+ // by accumulating the running linearized value.
+ // Note that `indices` and `allocSizes` are passed in the same order as they
+ // appear in load/store operations and memref type declarations.
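+  //
+  // Worked example for a 3-D access (illustrative): the loop below computes
+  // ((i_3 * s_2) + i_2) * s_1 + i_1, which expands to the sum above:
+  // i_1 + i_2 * s_1 + i_3 * s_2 * s_1.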
+ Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
+ ArrayRef<Value> indices,
+ ArrayRef<Value> allocSizes) const {
+ assert(indices.size() == allocSizes.size() &&
+ "mismatching number of indices and allocation sizes");
+ assert(!indices.empty() && "cannot linearize a 0-dimensional access");
+
+ Value linearized = indices.front();
+ for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
+ linearized = builder.create<LLVM::MulOp>(
+ loc, this->getIndexType(),
+ ArrayRef<Value>{linearized, allocSizes[i]});
+ linearized = builder.create<LLVM::AddOp>(
+ loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
+ }
+ return linearized;
+ }
+
+ // This is a strided getElementPtr variant that linearizes subscripts as:
+ // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
+ Value getStridedElementPtr(Location loc, Type elementTypePtr,
+ Value descriptor, ArrayRef<Value> indices,
+ ArrayRef<int64_t> strides, int64_t offset,
+ ConversionPatternRewriter &rewriter) const {
+ MemRefDescriptor memRefDescriptor(descriptor);
+
+ Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+ Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.offset(rewriter, loc)
+ : this->createIndexConstant(rewriter, loc, offset);
+
+ for (int i = 0, e = indices.size(); i < e; ++i) {
+ Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.stride(rewriter, loc, i)
+ : this->createIndexConstant(rewriter, loc, strides[i]);
+ Value additionalOffset =
+ rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
+ offsetValue =
+ rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
+ }
+ return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
+ }
+
+ Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
+ ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
+ llvm::Module &module) const {
+ LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(type, strides, offset);
+ assert(succeeded(successStrides) && "unexpected non-strided memref");
+ (void)successStrides;
+ return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
+ offset, rewriter);
+ }
+};
+
+// Load operation is lowered to obtaining a pointer to the indexed element
+// and loading it.
+struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
+ using Base::Base;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loadOp = cast<LoadOp>(op);
+ OperandAdaptor<LoadOp> transformed(operands);
+ auto type = loadOp.getMemRefType();
+
+ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
+ return matchSuccess();
+ }
+};
+
+// A store operation is lowered to computing a pointer to the indexed element
+// and storing the given value to it.
+struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
+ using Base::Base;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto type = cast<StoreOp>(op).getMemRefType();
+ OperandAdaptor<StoreOp> transformed(operands);
+
+ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
+ dataPtr);
+ return matchSuccess();
+ }
+};
+
+// The prefetch operation is lowered in a way similar to the load operation
+// except that the llvm.prefetch operation is used for replacement.
+struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
+ using Base::Base;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto prefetchOp = cast<PrefetchOp>(op);
+ OperandAdaptor<PrefetchOp> transformed(operands);
+ auto type = prefetchOp.getMemRefType();
+
+ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
+
+ // Replace with llvm.prefetch.
+ auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32));
+ auto isWrite = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmI32Type,
+ rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
+ auto localityHint = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmI32Type,
+ rewriter.getI32IntegerAttr(prefetchOp.localityHint().getZExtValue()));
+ auto isData = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmI32Type,
+ rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
+
+ rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
+ localityHint, isData);
+ return matchSuccess();
+ }
+};
+
+// The lowering of index_cast becomes an integer conversion since index
+// becomes an integer. If the bit widths of the source and target integer
+// types match, just erase the cast; if the target type is wider, sign-extend
+// the value, otherwise truncate it.
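+//
+// For instance (illustrative, assuming a 64-bit index type), an index_cast
+// from index to i32 lowers to llvm.trunc, while a cast from index to i64
+// simply forwards the operand.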
+struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
+ using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ IndexCastOpOperandAdaptor transformed(operands);
+ auto indexCastOp = cast<IndexCastOp>(op);
+
+ auto targetType =
+ this->lowering.convertType(indexCastOp.getResult()->getType())
+ .cast<LLVM::LLVMType>();
+ auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
+ unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
+ unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
+
+ if (targetBits == sourceBits)
+ rewriter.replaceOp(op, transformed.in());
+ else if (targetBits < sourceBits)
+ rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
+ transformed.in());
+ else
+ rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
+ transformed.in());
+ return matchSuccess();
+ }
+};
+
+// Convert a std.cmp predicate into the LLVM dialect CmpPredicate. The two
+// enums share the numerical values, so a plain cast is enough.
+template <typename LLVMPredType, typename StdPredType>
+static LLVMPredType convertCmpPredicate(StdPredType pred) {
+ return static_cast<LLVMPredType>(pred);
+}
+
+struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
+ using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto cmpiOp = cast<CmpIOp>(op);
+ CmpIOpOperandAdaptor transformed(operands);
+
+ rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+ op, lowering.convertType(cmpiOp.getResult()->getType()),
+ rewriter.getI64IntegerAttr(static_cast<int64_t>(
+ convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
+ transformed.lhs(), transformed.rhs());
+
+ return matchSuccess();
+ }
+};
+
+struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
+ using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto cmpfOp = cast<CmpFOp>(op);
+ CmpFOpOperandAdaptor transformed(operands);
+
+ rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
+ op, lowering.convertType(cmpfOp.getResult()->getType()),
+ rewriter.getI64IntegerAttr(static_cast<int64_t>(
+ convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
+ transformed.lhs(), transformed.rhs());
+
+ return matchSuccess();
+ }
+};
+
+struct SIToFPLowering
+ : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
+ using Super::Super;
+};
+
+struct FPExtLowering : public OneToOneLLVMOpLowering<FPExtOp, LLVM::FPExtOp> {
+ using Super::Super;
+};
+
+struct FPTruncLowering
+ : public OneToOneLLVMOpLowering<FPTruncOp, LLVM::FPTruncOp> {
+ using Super::Super;
+};
+
+struct SignExtendIOpLowering
+ : public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> {
+ using Super::Super;
+};
+
+struct TruncateIOpLowering
+ : public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> {
+ using Super::Super;
+};
+
+struct ZeroExtendIOpLowering
+ : public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> {
+ using Super::Super;
+};
+
+// Base class for lowering terminator operations with successors to the LLVM
+// IR dialect.
+template <typename SourceOp, typename TargetOp>
+struct OneToOneLLVMTerminatorLowering
+ : public LLVMLegalizationPattern<SourceOp> {
+ using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+ using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
+ ArrayRef<Block *> destinations,
+ ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<ValueRange, 2> operandRanges(operands.begin(), operands.end());
+ rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
+ operandRanges, op->getAttrs());
+ return this->matchSuccess();
+ }
+};
+
+// Special lowering pattern for `ReturnOps`. Unlike all other operations,
+// `ReturnOp` interacts with the function signature and must have as many
+// operands as the function has return values. Because LLVM IR functions can
+// only return 0 or 1 value, we pack multiple values into a structure type.
+// Emit `UndefOp` followed by `InsertValueOp`s to create such a structure if
+// necessary before returning it.
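+//
+// Illustrative sketch (assumed LLVM dialect syntax):
+// `return %a, %b : i64, f32` becomes roughly
+//   %0 = llvm.mlir.undef : !llvm<"{ i64, float }">
+//   %1 = llvm.insertvalue %a, %0[0] : !llvm<"{ i64, float }">
+//   %2 = llvm.insertvalue %b, %1[1] : !llvm<"{ i64, float }">
+//   llvm.return %2 : !llvm<"{ i64, float }">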
+struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
+ using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ unsigned numArguments = op->getNumOperands();
+
+ // If ReturnOp has 0 or 1 operand, create it and return immediately.
+ if (numArguments == 0) {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, ArrayRef<Value>(), ArrayRef<Block *>(), op->getAttrs());
+ return matchSuccess();
+ }
+ if (numArguments == 1) {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, ArrayRef<Value>(operands.front()), ArrayRef<Block *>(),
+ op->getAttrs());
+ return matchSuccess();
+ }
+
+ // Otherwise, we need to pack the arguments into an LLVM struct type before
+ // returning.
+ auto packedType =
+ lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
+
+ Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
+ for (unsigned i = 0; i < numArguments; ++i) {
+ packed = rewriter.create<LLVM::InsertValueOp>(
+ op->getLoc(), packedType, packed, operands[i],
+ rewriter.getI64ArrayAttr(i));
+ }
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, llvm::makeArrayRef(packed), ArrayRef<Block *>(), op->getAttrs());
+ return matchSuccess();
+ }
+};
+
+// FIXME: this should be tablegen'ed as well.
+struct BranchOpLowering
+ : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
+ using Super::Super;
+};
+struct CondBranchOpLowering
+ : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
+ using Super::Super;
+};
+
+// The Splat operation is lowered to an insertelement + a shufflevector
+// operation. Only splats to 1-D vector result types are lowered here.
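+//
+// Illustrative sketch (assumed LLVM dialect syntax):
+// `%v = splat %s : vector<4xf32>` becomes roughly
+//   %u = llvm.mlir.undef : !llvm<"<4 x float>">
+//   %c0 = llvm.mlir.constant(0 : i32) : !llvm.i32
+//   %i = llvm.insertelement %s, %u[%c0 : !llvm.i32] : !llvm<"<4 x float>">
+//   %v = llvm.shufflevector %i, %u [0, 0, 0, 0]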
+struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
+ using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splatOp = cast<SplatOp>(op);
+ VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
+ if (!resultType || resultType.getRank() != 1)
+ return matchFailure();
+
+ // First insert it into an undef vector so we can shuffle it.
+ auto vectorType = lowering.convertType(splatOp.getType());
+ Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
+ rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+
+ auto v = rewriter.create<LLVM::InsertElementOp>(
+ op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
+
+ int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
+ SmallVector<int32_t, 4> zeroValues(width, 0);
+
+ // Shuffle the value across the desired number of elements.
+ ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
+ return matchSuccess();
+ }
+};
+
+// The Splat operation is lowered to an insertelement + a shufflevector
+// operation. Only splats to vector result types of rank 2 or higher are
+// lowered by SplatNdOpLowering; the 1-D case is handled by SplatOpLowering.
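+//
+// Illustrative sketch (assumed type syntax): a splat to vector<2x4xf32> first
+// builds a splatted 1-D vector<4xf32> as in SplatOpLowering, then inserts it
+// at positions [0] and [1] of an !llvm<"[2 x <4 x float>]"> descriptor.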
+struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
+ using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splatOp = cast<SplatOp>(op);
+ OperandAdaptor<SplatOp> adaptor(operands);
+ VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
+ if (!resultType || resultType.getRank() == 1)
+ return matchFailure();
+
+ // First insert it into an undef vector so we can shuffle it.
+ auto loc = op->getLoc();
+ auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering);
+ auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
+ auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+ if (!llvmArrayTy || !llvmVectorTy)
+ return matchFailure();
+
+ // Construct returned value.
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+
+ // Construct a 1-D vector with the splatted value that we insert in all the
+ // places within the returned descriptor.
+ Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(rewriter.getIntegerType(32)),
+ rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+ Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
+ adaptor.input(), zero);
+
+ // Shuffle the value across the desired number of elements.
+ int64_t width = resultType.getDimSize(resultType.getRank() - 1);
+ SmallVector<int32_t, 4> zeroValues(width, 0);
+ ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+ v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
+
+    // Iterate over the linear index space, convert each linear index to a
+    // coordinate, and insert the splatted 1-D vector at that position.
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
+ position);
+ });
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
+/// Conversion pattern that transforms a subview op into:
+/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
+/// 2. Updates to the descriptor to introduce the data ptr, offset, size
+/// and stride.
+/// The subview op is replaced by the descriptor.
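+///
+/// Illustrative formulas for the updates performed below (not part of the
+/// original comment):
+///   new_offset    = src_offset + sum_i dynamicOffsets[i] * src_strides[i]
+///   new_size[i]   = dynamicSizes[i]   (or the static result shape)
+///   new_stride[i] = dynamicStrides[i] * src_strides[i]
+///                   (or the result type's static stride)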
+struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
+ using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto viewOp = cast<SubViewOp>(op);
+    // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support
+    // having multiple variadic operands where each operand can have a
+    // different number of entries, clean all of this up.
+ SmallVector<Value, 2> dynamicOffsets(
+ std::next(operands.begin()),
+ std::next(operands.begin(), 1 + viewOp.getNumOffsets()));
+ SmallVector<Value, 2> dynamicSizes(
+ std::next(operands.begin(), 1 + viewOp.getNumOffsets()),
+ std::next(operands.begin(),
+ 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()));
+ SmallVector<Value, 2> dynamicStrides(
+ std::next(operands.begin(),
+ 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()),
+ operands.end());
+
+ auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
+ auto sourceElementTy =
+ lowering.convertType(sourceMemRefType.getElementType())
+ .dyn_cast_or_null<LLVM::LLVMType>();
+
+ auto viewMemRefType = viewOp.getType();
+ auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+ auto targetDescTy =
+ lowering.convertType(viewMemRefType).dyn_cast_or_null<LLVM::LLVMType>();
+ if (!sourceElementTy || !targetDescTy)
+ return matchFailure();
+
+    // Currently, only results of rank > 0 are supported; dynamic offsets must
+    // be provided for every dimension, and dynamic sizes/strides must be
+    // either all present or all absent. Fail to convert otherwise.
+ unsigned rank = sourceMemRefType.getRank();
+ if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+ (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
+ (!dynamicStrides.empty() && rank != dynamicStrides.size()))
+ return matchFailure();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+ if (failed(successStrides))
+ return matchFailure();
+
+ // Create the descriptor.
+ MemRefDescriptor sourceMemRef(operands.front());
+ auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
+
+ // Copy the buffer pointer from the old descriptor to the new one.
+ Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(), extracted);
+ targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
+
+ extracted = sourceMemRef.alignedPtr(rewriter, loc);
+ bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(), extracted);
+ targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
+
+ // Extract strides needed to compute offset.
+ SmallVector<Value, 4> strideValues;
+ strideValues.reserve(viewMemRefType.getRank());
+ for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
+ strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
+
+ // Fill in missing dynamic sizes.
+ auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
+ if (dynamicSizes.empty()) {
+ dynamicSizes.reserve(viewMemRefType.getRank());
+ auto shape = viewMemRefType.getShape();
+ for (auto extent : shape) {
+ dynamicSizes.push_back(rewriter.create<LLVM::ConstantOp>(
+ loc, llvmIndexType, rewriter.getI64IntegerAttr(extent)));
+ }
+ }
+
+ // Offset.
+ Value baseOffset = sourceMemRef.offset(rewriter, loc);
+ for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+ Value min = dynamicOffsets[i];
+ baseOffset = rewriter.create<LLVM::AddOp>(
+ loc, baseOffset,
+ rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+ }
+ targetMemRef.setOffset(rewriter, loc, baseOffset);
+
+ // Update sizes and strides.
+ for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+ targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]);
+ Value newStride;
+ if (dynamicStrides.empty())
+ newStride = rewriter.create<LLVM::ConstantOp>(
+ loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
+ else
+ newStride = rewriter.create<LLVM::MulOp>(loc, dynamicStrides[i],
+ strideValues[i]);
+ targetMemRef.setStride(rewriter, loc, i, newStride);
+ }
+
+ rewriter.replaceOp(op, {targetMemRef});
+ return matchSuccess();
+ }
+};
+
+/// Conversion pattern that transforms a view op into:
+/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
+/// 2. Updates to the descriptor to introduce the data ptr, offset, size
+/// and stride.
+/// The view op is replaced by the descriptor.
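+///
+/// Illustrative sketch (not part of the original comment): for a contiguous
+/// result of shape (d0, d1, d2) with dynamic strides, the strides are rebuilt
+/// back-to-front as stride2 = 1, stride1 = d2, stride0 = d1 * d2, using the
+/// running-stride scheme in getStride() below.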
+struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
+ using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
+
+  // Build and return the value for the idx^th shape dimension, either by
+  // returning the constant shape dimension or by picking the corresponding
+  // dynamic size operand.
+ Value getSize(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> shape, ArrayRef<Value> dynamicSizes,
+ unsigned idx) const {
+ assert(idx < shape.size());
+ if (!ShapedType::isDynamic(shape[idx]))
+ return createIndexConstant(rewriter, loc, shape[idx]);
+    // Count the number of dynamic dims in the range [0, idx).
+ unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
+ return ShapedType::isDynamic(v);
+ });
+ return dynamicSizes[nDynamic];
+ }
+
+ // Build and return the idx^th stride, either by returning the constant stride
+ // or by computing the dynamic stride from the current `runningStride` and
+ // `nextSize`. The caller should keep a running stride and update it with the
+ // result returned by this function.
+ Value getStride(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> strides, Value nextSize,
+ Value runningStride, unsigned idx) const {
+ assert(idx < strides.size());
+ if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+ return createIndexConstant(rewriter, loc, strides[idx]);
+ if (nextSize)
+ return runningStride
+ ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
+ : nextSize;
+ assert(!runningStride);
+ return createIndexConstant(rewriter, loc, 1);
+ }
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto viewOp = cast<ViewOp>(op);
+ ViewOpOperandAdaptor adaptor(operands);
+
+ auto viewMemRefType = viewOp.getType();
+ auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+ auto targetDescTy =
+ lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+ if (!targetDescTy)
+ return op->emitWarning("Target descriptor type not converted to LLVM"),
+ matchFailure();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+ if (failed(successStrides))
+ return op->emitWarning("cannot cast to non-strided shape"),
+ matchFailure();
+
+ // Create the descriptor.
+ MemRefDescriptor sourceMemRef(adaptor.source());
+ auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
+
+ // Field 1: Copy the allocated pointer, used for malloc/free.
+ Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(), extracted);
+ targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
+
+ // Field 2: Copy the actual aligned pointer to payload.
+ extracted = sourceMemRef.alignedPtr(rewriter, loc);
+ bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(), extracted);
+ targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
+
+ // Field 3: Copy the offset in aligned pointer.
+ unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
+ (void)numDynamicSizes;
+ bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset();
+ auto sizeAndOffsetOperands = adaptor.operands();
+ assert(llvm::size(sizeAndOffsetOperands) ==
+ numDynamicSizes + (hasDynamicOffset ? 1 : 0));
+ Value baseOffset = !hasDynamicOffset
+ ? createIndexConstant(rewriter, loc, offset)
+ // TODO(ntv): better adaptor.
+ : sizeAndOffsetOperands.front();
+ targetMemRef.setOffset(rewriter, loc, baseOffset);
+
+ // Early exit for 0-D corner case.
+ if (viewMemRefType.getRank() == 0)
+ return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
+
+ // Fields 4 and 5: Update sizes and strides.
+ if (strides.back() != 1)
+ return op->emitWarning("cannot cast to non-contiguous shape"),
+ matchFailure();
+ Value stride = nullptr, nextSize = nullptr;
+ // Drop the dynamic stride from the operand list, if present.
+ ArrayRef<Value> sizeOperands(sizeAndOffsetOperands);
+ if (hasDynamicOffset)
+ sizeOperands = sizeOperands.drop_front();
+ for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+ // Update size.
+ Value size =
+ getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i);
+ targetMemRef.setSize(rewriter, loc, i, size);
+ // Update stride.
+ stride = getStride(rewriter, loc, strides, nextSize, stride, i);
+ targetMemRef.setStride(rewriter, loc, i, stride);
+ nextSize = size;
+ }
+
+ rewriter.replaceOp(op, {targetMemRef});
+ return matchSuccess();
+ }
+};
+
+} // namespace
+
+static void ensureDistinctSuccessors(Block &bb) {
+ auto *terminator = bb.getTerminator();
+
+ // Find repeated successors with arguments.
+ llvm::SmallDenseMap<Block *, SmallVector<int, 4>> successorPositions;
+ for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
+ Block *successor = terminator->getSuccessor(i);
+ // Blocks with no arguments are safe even if they appear multiple times
+ // because they don't need PHI nodes.
+ if (successor->getNumArguments() == 0)
+ continue;
+ successorPositions[successor].push_back(i);
+ }
+
+  // For each repeated occurrence of a successor in the terminator (the second
+  // and subsequent uses), create a new dummy block that unconditionally
+  // branches to the original destination, and retarget the terminator to
+  // branch to this new block.
+ // There is no need to pass arguments to the dummy block because it will be
+ // dominated by the original block and can therefore use any values defined in
+ // the original block.
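+  //
+  // For example (illustrative), `cond_br %c, ^bb1(%a : i64), ^bb1(%b : i64)`
+  // is rewritten so that one of the two edges targets a fresh block that only
+  // contains `br ^bb1(...)`, making the two successors distinct.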
+ for (const auto &successor : successorPositions) {
+ const auto &positions = successor.second;
+ // Start from the second occurrence of a block in the successor list.
+ for (auto position = std::next(positions.begin()), end = positions.end();
+ position != end; ++position) {
+ auto *dummyBlock = new Block();
+ bb.getParent()->push_back(dummyBlock);
+ auto builder = OpBuilder(dummyBlock);
+ SmallVector<Value, 8> operands(
+ terminator->getSuccessorOperands(*position));
+ builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
+ terminator->setSuccessor(dummyBlock, *position);
+ for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e;
+ ++i)
+ terminator->eraseSuccessorOperand(*position, i);
+ }
+ }
+}
+
+void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
+ for (auto f : m.getOps<FuncOp>()) {
+ for (auto &bb : f.getBlocks()) {
+ ::ensureDistinctSuccessors(bb);
+ }
+ }
+}
+
+/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+void mlir::populateStdToLLVMNonMemoryConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ // FIXME: this should be tablegen'ed
+ // clang-format off
+ patterns.insert<
+ AbsFOpLowering,
+ AddFOpLowering,
+ AddIOpLowering,
+ AndOpLowering,
+ BranchOpLowering,
+ CallIndirectOpLowering,
+ CallOpLowering,
+ CeilFOpLowering,
+ CmpFOpLowering,
+ CmpIOpLowering,
+ CondBranchOpLowering,
+ CopySignOpLowering,
+ CosOpLowering,
+ ConstLLVMOpLowering,
+ DivFOpLowering,
+ ExpOpLowering,
+ LogOpLowering,
+ Log10OpLowering,
+ Log2OpLowering,
+ FPExtLowering,
+ FPTruncLowering,
+ IndexCastOpLowering,
+ MulFOpLowering,
+ MulIOpLowering,
+ NegFOpLowering,
+ OrOpLowering,
+ PrefetchOpLowering,
+ RemFOpLowering,
+ ReturnOpLowering,
+ SIToFPLowering,
+ SelectOpLowering,
+ ShiftLeftOpLowering,
+ SignExtendIOpLowering,
+ SignedDivIOpLowering,
+ SignedRemIOpLowering,
+ SignedShiftRightOpLowering,
+ SplatOpLowering,
+ SplatNdOpLowering,
+ SubFOpLowering,
+ SubIOpLowering,
+ TanhOpLowering,
+ TruncateIOpLowering,
+ UnsignedDivIOpLowering,
+ UnsignedRemIOpLowering,
+ UnsignedShiftRightOpLowering,
+ XOrOpLowering,
+ ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+ // clang-format on
+}
+
+void mlir::populateStdToLLVMMemoryConversionPatters(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ // clang-format off
+ patterns.insert<
+ DimOpLowering,
+ FuncOpConversion,
+ LoadOpLowering,
+ MemRefCastOpLowering,
+ StoreOpLowering,
+ SubViewOpLowering,
+ ViewOpLowering>(*converter.getDialect(), converter);
+ patterns.insert<
+ AllocOpLowering,
+ DeallocOpLowering>(
+ *converter.getDialect(), converter, clUseAlloca.getValue());
+ // clang-format on
+}
+
+void mlir::populateStdToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
+ populateStdToLLVMMemoryConversionPatters(converter, patterns);
+}
+
+// Convert types using the stored LLVM IR module.
+Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
+
+// Create an LLVM IR structure type if there is more than one result.
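+// For example (illustrative, assumed type syntax), result types (i32, f32)
+// are packed into the LLVM struct type !llvm<"{ i32, float }">, while a
+// single result is converted directly.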
+Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
+  assert(!types.empty() && "expected non-empty list of types");
+
+ if (types.size() == 1)
+ return convertType(types.front());
+
+ SmallVector<LLVM::LLVMType, 8> resultTypes;
+ resultTypes.reserve(types.size());
+ for (auto t : types) {
+ auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
+ if (!converted)
+ return {};
+ resultTypes.push_back(converted);
+ }
+
+ return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
+}
+
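+// Allocates a stack slot with llvm.alloca, stores the memref descriptor
+// `operand` into it, and returns the pointer to the slot so the descriptor
+// can be passed by pointer.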
+Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
+ OpBuilder &builder) {
+ auto *context = builder.getContext();
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
+ auto indexType = IndexType::get(context);
+  // Alloca with proper alignment. We do not expect this alloca op to be
+  // optimized, so we do not bother emitting it in the entry block.
+ auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo();
+ Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
+ IntegerAttr::get(indexType, 1));
+ Value allocated =
+ builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
+ // Store into the alloca'ed descriptor.
+ builder.create<LLVM::StoreOp>(loc, operand, allocated);
+ return allocated;
+}
+
+SmallVector<Value, 4>
+LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
+ ValueRange operands,
+ OpBuilder &builder) {
+ SmallVector<Value, 4> promotedOperands;
+ promotedOperands.reserve(operands.size());
+ for (auto it : llvm::zip(opOperands, operands)) {
+ auto operand = std::get<0>(it);
+ auto llvmOperand = std::get<1>(it);
+ if (!operand->getType().isa<MemRefType>() &&
+ !operand->getType().isa<UnrankedMemRefType>()) {
+ promotedOperands.push_back(operand);
+ continue;
+ }
+ promotedOperands.push_back(
+ promoteOneMemRefDescriptor(loc, llvmOperand, builder));
+ }
+ return promotedOperands;
+}
+
+/// Create an instance of LLVMTypeConverter in the given context.
+static std::unique_ptr<LLVMTypeConverter>
+makeStandardToLLVMTypeConverter(MLIRContext *context) {
+ return std::make_unique<LLVMTypeConverter>(context);
+}
+
+namespace {
+/// A pass converting MLIR operations into the LLVM IR dialect.
+struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
+ // By default, the patterns are those converting Standard operations to the
+ // LLVMIR dialect.
+ explicit LLVMLoweringPass(
+ bool useAlloca = false,
+ LLVMPatternListFiller patternListFiller =
+ populateStdToLLVMConversionPatterns,
+ LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
+ : patternListFiller(patternListFiller),
+ typeConverterMaker(converterBuilder) {}
+
+ // Run the dialect converter on the module.
+ void runOnModule() override {
+ if (!typeConverterMaker || !patternListFiller)
+ return signalPassFailure();
+
+ ModuleOp m = getModule();
+ LLVM::ensureDistinctSuccessors(m);
+ std::unique_ptr<LLVMTypeConverter> typeConverter =
+ typeConverterMaker(&getContext());
+ if (!typeConverter)
+ return signalPassFailure();
+
+ OwningRewritePatternList patterns;
+ populateLoopToStdConversionPatterns(patterns, m.getContext());
+ patternListFiller(*typeConverter, patterns);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
+ signalPassFailure();
+ }
+
+ // Callback for creating a list of patterns. It is called every time in
+ // runOnModule since applyPartialConversion consumes the list.
+ LLVMPatternListFiller patternListFiller;
+
+ // Callback for creating an instance of type converter. The converter
+ // constructor needs an MLIRContext, which is not available until runOnModule.
+ LLVMTypeConverterMaker typeConverterMaker;
+};
+} // end namespace
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createLowerToLLVMPass(bool useAlloca) {
+ return std::make_unique<LLVMLoweringPass>(useAlloca);
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
+ LLVMTypeConverterMaker typeConverterMaker,
+ bool useAlloca) {
+ return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller,
+ typeConverterMaker);
+}
+
+static PassRegistration<LLVMLoweringPass>
+ pass("convert-std-to-llvm",
+ "Convert scalar and vector operations from the "
+ "Standard to the LLVM dialect",
+ [] {
+ return std::make_unique<LLVMLoweringPass>(
+ clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
+ makeStandardToLLVMTypeConverter);
+ });
diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
new file mode 100644
index 00000000000..fcced23a95e
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
+mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
+
+add_llvm_library(MLIRStandardToSPIRVTransforms
+ ConvertStandardToSPIRV.cpp
+ ConvertStandardToSPIRVPass.cpp
+ LegalizeStandardForSPIRV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+ )
+
+add_dependencies(MLIRStandardToSPIRVTransforms
+ MLIRStandardToSPIRVIncGen)
+
+target_link_libraries(MLIRStandardToSPIRVTransforms
+ MLIRIR
+ MLIRPass
+ MLIRSPIRV
+ MLIRSupport
+ MLIRTransformUtils
+ MLIRSPIRV
+ MLIRStandardOps
+ )
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
new file mode 100644
index 00000000000..a02dee4419a
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -0,0 +1,314 @@
+//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineMap.h"
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Convert constant operation with IndexType return to SPIR-V constant
+/// operation. Since IndexType is not used within SPIR-V dialect, this needs
+/// special handling to make sure the result type and the type of the value
+/// attribute are consistent.
+// TODO(ravishankarm) : This should be moved into DRR.
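+// Illustrative example (assuming the type converter maps index to i32):
+// `%c = constant 42 : index` becomes `%c = spv.constant 42 : i32`.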
+class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
+public:
+ using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert compare operation to SPIR-V dialect.
+class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
+public:
+ using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use
+/// tablegen for this. If the integer operation is on variables of IndexType,
+/// the type of the return value of the replacement operation differs from
+/// that of the replaced operation. This is not handled in tablegen-based
+/// pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
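+// Illustrative example (assuming the type converter maps index to i32):
+// `%0 = addi %a, %b : index` becomes `%0 = spv.IAdd %a, %b : i32`.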
+template <typename StdOp, typename SPIRVOp>
+class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
+public:
+ using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultType =
+ this->typeConverter.convertType(operation.getResult()->getType());
+ rewriter.template replaceOpWithNewOp<SPIRVOp>(
+ operation, resultType, operands, ArrayRef<NamedAttribute>());
+ return this->matchSuccess();
+ }
+};
+
+/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
+/// IndexType while those of the replacement operation are of type i32. This is
+/// not supported in tablegen-based pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
+class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
+public:
+ using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert return -> spv.Return.
+// TODO(ravishankarm) : This should be moved into DRR.
+class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
+public:
+ using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert select -> spv.Select
+// TODO(ravishankarm) : This should be moved into DRR.
+class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
+public:
+ using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
+ PatternMatchResult
+ matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are
+/// of IndexType while those of the replacement operation are of type i32. This
+/// is not supported in tablegen-based pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
+class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
+public:
+ using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for operation conversion
+//===----------------------------------------------------------------------===//
+
+/// Performs the index computation to get to the element pointed to by
+/// `indices` using the layout map of `origBaseType`.
+// TODO(ravishankarm) : This method assumes that `origBaseType` is a MemRefType
+// with an AffineMap that has static strides. Handle dynamic strides.
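+//
+// Illustrative example (assuming an i32 index type): a load from
+// memref<12x42xf32> at (%i, %j) with static strides [42, 1] produces roughly
+//   %lin = %i * 42 + %j
+//   spv.AccessChain %base[%zero, %lin]
+// where the leading zero indexes into the wrapping struct.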
+spirv::AccessChainOp getElementPtr(OpBuilder &builder,
+ SPIRVTypeConverter &typeConverter,
+ Location loc, MemRefType origBaseType,
+ Value basePtr, ArrayRef<Value> indices) {
+ // Get base and offset of the MemRefType and verify they are static.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
+ llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+ return nullptr;
+ }
+
+ auto indexType = typeConverter.getIndexType(builder.getContext());
+
+ Value ptrLoc = nullptr;
+ assert(indices.size() == strides.size());
+ for (auto index : enumerate(indices)) {
+ Value strideVal = builder.create<spirv::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+ Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+ ptrLoc =
+ (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
+ : update);
+ }
+ SmallVector<Value, 2> linearizedIndices;
+ // Add a '0' at the start to index into the struct.
+ linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, 0)));
+ linearizedIndices.push_back(ptrLoc);
+ return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp with index type.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
+ ConstantOp constIndexOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
+ return matchFailure();
+ }
+ // The attribute has index type which is not directly supported in
+ // SPIR-V. Get the integer value and create a new IntegerAttr.
+ auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
+ if (!constAttr) {
+ return matchFailure();
+ }
+
+  // SPIR-V does not support index types, so use the converted result type
+  // (an integer type) for the SPIR-V constant operation and rebuild the value
+  // attribute with it.
+ auto constVal = constAttr.getValue();
+ auto constValType = constAttr.getType().dyn_cast<IndexType>();
+ if (!constValType) {
+ return matchFailure();
+ }
+ auto spirvConstType =
+ typeConverter.convertType(constIndexOp.getResult()->getType());
+ auto spirvConstVal =
+ rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
+ spirvConstVal);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ CmpIOpOperandAdaptor cmpIOpOperands(operands);
+
+ switch (cmpIOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp) \
+ case cmpPredicate: \
+ rewriter.replaceOpWithNewOp<spirvOp>( \
+ cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \
+ cmpIOpOperands.rhs()); \
+ return matchSuccess();
+
+ DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
+ DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
+ DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
+ DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
+ DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
+ DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+
+#undef DISPATCH
+
+ default:
+ break;
+ }
+ return matchFailure();
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ LoadOpOperandAdaptor loadOperands(operands);
+ auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
+ loadOp.memref()->getType().cast<MemRefType>(),
+ loadOperands.memref(), loadOperands.indices());
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
+ /*memory_access =*/nullptr,
+ /*alignment =*/nullptr);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (returnOp.getNumOperands()) {
+ return matchFailure();
+ }
+ rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ SelectOpOperandAdaptor selectOperands(operands);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
+ selectOperands.true_value(),
+ selectOperands.false_value());
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ StoreOpOperandAdaptor storeOperands(operands);
+ auto storePtr =
+ getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
+ storeOp.memref()->getType().cast<MemRefType>(),
+ storeOperands.memref(), storeOperands.indices());
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
+ storeOperands.value(),
+ /*memory_access =*/nullptr,
+ /*alignment =*/nullptr);
+ return matchSuccess();
+}
+
+namespace {
+/// Import the Standard Ops to SPIR-V Patterns.
+#include "StandardToSPIRV.cpp.inc"
+} // namespace
+
+namespace mlir {
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ // Add patterns that lower operations into SPIR-V dialect.
+ populateWithGenerated(context, &patterns);
+ patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
+ IntegerOpConversion<AddIOp, spirv::IAddOp>,
+ IntegerOpConversion<MulIOp, spirv::IMulOp>,
+ IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
+ IntegerOpConversion<SignedRemIOp, spirv::SModOp>,
+ IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
+ ReturnOpConversion, SelectOpConversion, StoreOpConversion>(
+ context, typeConverter);
+}
+} // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
new file mode 100644
index 00000000000..52456b6e46d
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -0,0 +1,89 @@
+//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard ops into the SPIR-V
+// ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// A simple pattern for rewriting a function signature so that the function's
+/// arguments are converted to valid SPIR-V types.
+class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
+public:
+ using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// A pass converting MLIR Standard operations into the SPIR-V dialect.
+class ConvertStandardToSPIRVPass
+ : public ModulePass<ConvertStandardToSPIRVPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+PatternMatchResult
+FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto fnType = funcOp.getType();
+ if (fnType.getNumResults()) {
+ return matchFailure();
+ }
+
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ {
+ for (auto argType : enumerate(funcOp.getType().getInputs())) {
+ auto convertedType = typeConverter.convertType(argType.value());
+ signatureConverter.addInputs(argType.index(), convertedType);
+ }
+ }
+
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+ });
+ return matchSuccess();
+}
+
+void ConvertStandardToSPIRVPass::runOnModule() {
+ OwningRewritePatternList patterns;
+ auto context = &getContext();
+ auto module = getModule();
+
+ SPIRVTypeConverter typeConverter;
+ populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+ patterns.insert<FuncOpConversion>(context, typeConverter);
+ ConversionTarget target(*(module.getContext()));
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
+
+ if (failed(applyPartialConversion(module, target, patterns))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() {
+ return std::make_unique<ConvertStandardToSPIRVPass>();
+}
+
+static PassRegistration<ConvertStandardToSPIRVPass>
+ pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
new file mode 100644
index 00000000000..a658356f76c
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -0,0 +1,181 @@
+//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes operations before the conversion to SPIR-V
+// dialect to handle ops that cannot be lowered directly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Merges subview operation with load operation.
+class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
+public:
+ using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges subview operation with store operation.
+class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
+public:
+ using OpRewritePattern<StoreOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Utility functions for op legalization.
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t. the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
+/// memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
+/// memref<12x42xf32>
+static LogicalResult
+resolveSourceIndices(Location loc, PatternRewriter &rewriter,
+ SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+  // TODO: Abort when the offsets are static (i.e., there are no offset
+  // operands). There might be a way to fold the subview op with the load even
+  // if the offsets have been canonicalized away.
+ if (subViewOp.getNumOffsets() == 0)
+ return failure();
+
+ ValueRange opOffsets = subViewOp.offsets();
+ SmallVector<Value, 2> opStrides;
+ if (subViewOp.getNumStrides()) {
+ // If the strides are dynamic, get the stride operands.
+ opStrides = llvm::to_vector<2>(subViewOp.strides());
+ } else {
+    // When static, the stride operands can be retrieved by taking the strides
+    // of the result of the subview op and dividing them by the strides of the
+    // base memref.
+ SmallVector<int64_t, 2> staticStrides;
+ if (failed(subViewOp.getStaticStrides(staticStrides))) {
+ return failure();
+ }
+ opStrides.reserve(opOffsets.size());
+ for (auto stride : staticStrides) {
+ auto constValAttr = rewriter.getIntegerAttr(
+ IndexType::get(rewriter.getContext()), stride);
+ opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr));
+ }
+ }
+ assert(opOffsets.size() == opStrides.size());
+
+ // New indices for the load are the current indices * subview_stride +
+ // subview_offset.
+ assert(indices.size() == opStrides.size());
+ sourceIndices.resize(indices.size());
+ for (auto index : llvm::enumerate(indices)) {
+ auto offset = opOffsets[index.index()];
+ auto stride = opStrides[index.index()];
+ auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
+ sourceIndices[index.index()] =
+ rewriter.create<AddIOp>(loc, offset, mul).getResult();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and LoadOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
+ loadOp.indices(), sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
+ sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Folding SubViewOp and StoreOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
+ PatternRewriter &rewriter) const {
+ auto subViewOp =
+ dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp());
+ if (!subViewOp) {
+ return matchFailure();
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
+ storeOp.indices(), sourceIndices)))
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
+ subViewOp.source(), sourceIndices);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Hook for adding patterns.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateStdLegalizationPatternsForSPIRVLowering(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass for testing just the legalization patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void SPIRVLegalization::runOnOperation() {
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+ applyPatternsGreedily(getOperation()->getRegions(), patterns);
+}
+
+std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
+ return std::make_unique<SPIRVLegalization>();
+}
+
+static PassRegistration<SPIRVLegalization>
+ pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering");
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
new file mode 100644
index 00000000000..6f3a6a82476
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
@@ -0,0 +1,35 @@
+//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower standard ops to SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_TD
+#define MLIR_CONVERSION_STANDARDTOSPIRV_TD
+
+include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/SPIRV/SPIRVOps.td"
+
+class BinaryOpPattern<Op src, Op tgt> :
+ Pat<(src SPV_ScalarOrVector:$l, SPV_ScalarOrVector:$r),
+ (tgt $l, $r)>;
+
+def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
+def : BinaryOpPattern<DivFOp, SPV_FDivOp>;
+def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
+def : BinaryOpPattern<RemFOp, SPV_FRemOp>;
+def : BinaryOpPattern<SubFOp, SPV_FSubOp>;
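+
+// For instance (illustrative), the AddFOp pattern above rewrites
+//   %0 = addf %a, %b : f32
+// into
+//   %0 = spv.FAdd %a, %b : f32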
+
+// Constant Op
+// TODO(ravishankarm): Handle lowering other constant types.
+def : Pat<(ConstantOp:$result $valueAttr),
+ (SPV_ConstantOp $valueAttr),
+ [(SPV_ScalarOrVector $result)]>;
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000..2aaec68f6c4
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRVectorToLLVM
+ ConvertVectorToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM
+)
+set(LIBS
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+ )
+
+add_dependencies(MLIRVectorToLLVM ${LIBS})
+target_link_libraries(MLIRVectorToLLVM ${LIBS})
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
new file mode 100644
index 00000000000..b48930c4dda
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -0,0 +1,766 @@
+//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+
+template <typename T>
+static LLVM::LLVMType getPtrToElementType(T containerType,
+ LLVMTypeConverter &lowering) {
+ return lowering.convertType(containerType.getElementType())
+ .template cast<LLVM::LLVMType>()
+ .getPointerTo();
+}
+
+// Helper to reduce vector type by one rank at front.
+static VectorType reducedVectorTypeFront(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
+}
+
+// Helper to reduce vector type by *all* but one rank at back.
+static VectorType reducedVectorTypeBack(VectorType tp) {
+ assert((tp.getRank() > 1) && "unlowerable vector type");
+ return VectorType::get(tp.getShape().take_back(), tp.getElementType());
+}
+
+// Helper that picks the proper sequence for inserting.
+static Value insertOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc, Value val1,
+ Value val2, Type llvmType, int64_t rank, int64_t pos) {
+ if (rank == 1) {
+ auto idxType = rewriter.getIndexType();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(idxType),
+ rewriter.getIntegerAttr(idxType, pos));
+ return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
+ constant);
+ }
+ return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
+ rewriter.getI64ArrayAttr(pos));
+}
+
+// Helper that picks the proper sequence for extracting.
+static Value extractOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc, Value val,
+ Type llvmType, int64_t rank, int64_t pos) {
+ if (rank == 1) {
+ auto idxType = rewriter.getIndexType();
+ auto constant = rewriter.create<LLVM::ConstantOp>(
+ loc, lowering.convertType(idxType),
+ rewriter.getIntegerAttr(idxType, pos));
+ return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
+ constant);
+ }
+ return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
+ rewriter.getI64ArrayAttr(pos));
+}
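+// Note on the two helpers above (an illustrative summary): for rank == 1 they
+// emit llvm.insertelement / llvm.extractelement with the position materialized
+// as a constant value operand; for rank > 1 they emit llvm.insertvalue /
+// llvm.extractvalue with the position encoded as an i64 array attribute, since
+// nested vectors lower to LLVM arrays of vectors.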
+
+class VectorBroadcastOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorBroadcastOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto broadcastOp = cast<vector::BroadcastOp>(op);
+ VectorType dstVectorType = broadcastOp.getVectorType();
+ if (lowering.convertType(dstVectorType) == nullptr)
+ return matchFailure();
+ // Rewrite when the full vector type can be lowered (which
+ // implies all 'reduced' types can be lowered too).
+ auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
+ VectorType srcVectorType =
+ broadcastOp.getSourceType().dyn_cast<VectorType>();
+ rewriter.replaceOp(
+ op, expandRanks(adaptor.source(), // source value to be expanded
+ op->getLoc(), // location of original broadcast
+ srcVectorType, dstVectorType, rewriter));
+ return matchSuccess();
+ }
+
+private:
+ // Expands the given source value over all the ranks, as defined
+ // by the source and destination type (a null source type denotes
+ // expansion from a scalar value into a vector).
+ //
+ // TODO(ajcbik): consider replacing this one-pattern lowering
+ // with a two-pattern lowering using other vector
+ // ops once all insert/extract/shuffle operations
+  //                 are available with lowering implementation.
+ //
+ Value expandRanks(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType,
+ ConversionPatternRewriter &rewriter) const {
+ assert((dstVectorType != nullptr) && "invalid result type in broadcast");
+ // Determine rank of source and destination.
+ int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
+ int64_t dstRank = dstVectorType.getRank();
+ int64_t curDim = dstVectorType.getDimSize(0);
+ if (srcRank < dstRank)
+ // Duplicate this rank.
+ return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
+ curDim, rewriter);
+ // If all trailing dimensions are the same, the broadcast consists of
+ // simply passing through the source value and we are done. Otherwise,
+ // any non-matching dimension forces a stretch along this rank.
+ assert((srcVectorType != nullptr) && (srcRank > 0) &&
+ (srcRank == dstRank) && "invalid rank in broadcast");
+ for (int64_t r = 0; r < dstRank; r++) {
+ if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
+ return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
+ curDim, rewriter);
+ }
+ }
+ return value;
+ }
+
+ // Picks the best way to duplicate a single rank. For the 1-D case, a
+ // single insert-elt/shuffle is the most efficient expansion. For higher
+ // dimensions, however, we need dim x insert-values on a new broadcast
+ // with one less leading dimension, which will be lowered "recursively"
+ // to matching LLVM IR.
+ // For example:
+ // v = broadcast s : f32 to vector<4x2xf32>
+ // becomes:
+ // x = broadcast s : f32 to vector<2xf32>
+ // v = [x,x,x,x]
+ // becomes:
+ // x = [s,s]
+ // v = [x,x,x,x]
+ Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
+ Type llvmType = lowering.convertType(dstVectorType);
+ assert((llvmType != nullptr) && "unlowerable vector type");
+ if (rank == 1) {
+ Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value expand =
+ insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
+ SmallVector<int32_t, 4> zeroValues(dim, 0);
+ return rewriter.create<LLVM::ShuffleVectorOp>(
+ loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
+ }
+ Value expand = expandRanks(value, loc, srcVectorType,
+ reducedVectorTypeFront(dstVectorType), rewriter);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ for (int64_t d = 0; d < dim; ++d) {
+ result =
+ insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
+ }
+ return result;
+ }
+
+ // Picks the best way to stretch a single rank. For the 1-D case, a
+ // single insert-elt/shuffle is the most efficient expansion when at
+  // a stretch. Otherwise, every dimension needs to be expanded and
+  // inserted individually into the resulting vector.
+ // For example:
+ // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
+ // becomes:
+ // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
+ // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
+ // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
+ // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
+ // v = [a,b,c,d]
+ // becomes:
+ // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
+ // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
+ // a = [x, y]
+ // etc.
+ Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
+ Type llvmType = lowering.convertType(dstVectorType);
+ assert((llvmType != nullptr) && "unlowerable vector type");
+ Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ bool atStretch = dim != srcVectorType.getDimSize(0);
+ if (rank == 1) {
+ assert(atStretch);
+ Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
+ Value one =
+ extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
+ Value expand =
+ insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
+ SmallVector<int32_t, 4> zeroValues(dim, 0);
+ return rewriter.create<LLVM::ShuffleVectorOp>(
+ loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
+ }
+ VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
+ VectorType redDstType = reducedVectorTypeFront(dstVectorType);
+ Type redLlvmType = lowering.convertType(redSrcType);
+ for (int64_t d = 0; d < dim; ++d) {
+ int64_t pos = atStretch ? 0 : d;
+ Value one =
+ extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
+ Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
+ result =
+ insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
+ }
+ return result;
+ }
+};
+
+class VectorShuffleOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorShuffleOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
+ auto shuffleOp = cast<vector::ShuffleOp>(op);
+ auto v1Type = shuffleOp.getV1VectorType();
+ auto v2Type = shuffleOp.getV2VectorType();
+ auto vectorType = shuffleOp.getVectorType();
+ Type llvmType = lowering.convertType(vectorType);
+ auto maskArrayAttr = shuffleOp.mask();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmType)
+ return matchFailure();
+
+ // Get rank and dimension sizes.
+ int64_t rank = vectorType.getRank();
+ assert(v1Type.getRank() == rank);
+ assert(v2Type.getRank() == rank);
+ int64_t v1Dim = v1Type.getDimSize(0);
+
+ // For rank 1, where both operands have *exactly* the same vector type,
+ // there is direct shuffle support in LLVM. Use it!
+ if (rank == 1 && v1Type == v2Type) {
+ Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+ loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
+ rewriter.replaceOp(op, shuffle);
+ return matchSuccess();
+ }
+
+    // For all other cases, extract and insert the values one by one.
+ Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ int64_t insPos = 0;
+ for (auto en : llvm::enumerate(maskArrayAttr)) {
+ int64_t extPos = en.value().cast<IntegerAttr>().getInt();
+ Value value = adaptor.v1();
+ if (extPos >= v1Dim) {
+ extPos -= v1Dim;
+ value = adaptor.v2();
+ }
+ Value extract =
+ extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
+ insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
+ rank, insPos++);
+ }
+ rewriter.replaceOp(op, insert);
+ return matchSuccess();
+ }
+};
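+// Note on the shuffle pattern above (an illustrative summary): a 1-D shuffle
+// of two operands with identical vector types maps directly onto
+// llvm.shufflevector with the original mask; any other case is unrolled into
+// one extract/insert pair per mask entry, reading from v1 for mask values
+// below the size of v1's leading dimension and from v2 otherwise.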
+
+class VectorExtractElementOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorExtractElementOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
+ auto extractEltOp = cast<vector::ExtractElementOp>(op);
+ auto vectorType = extractEltOp.getVectorType();
+ auto llvmType = lowering.convertType(vectorType.getElementType());
+
+ // Bail if result type cannot be lowered.
+ if (!llvmType)
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+ op, llvmType, adaptor.vector(), adaptor.position());
+ return matchSuccess();
+ }
+};
+
+class VectorExtractOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorExtractOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::ExtractOpOperandAdaptor(operands);
+ auto extractOp = cast<vector::ExtractOp>(op);
+ auto vectorType = extractOp.getVectorType();
+ auto resultType = extractOp.getResult()->getType();
+ auto llvmResultType = lowering.convertType(resultType);
+ auto positionArrayAttr = extractOp.position();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmResultType)
+ return matchFailure();
+
+ // One-shot extraction of vector from array (only requires extractvalue).
+ if (resultType.isa<VectorType>()) {
+ Value extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmResultType, adaptor.vector(), positionArrayAttr);
+ rewriter.replaceOp(op, extracted);
+ return matchSuccess();
+ }
+
+ // Potential extraction of 1-D vector from array.
+ auto *context = op->getContext();
+ Value extracted = adaptor.vector();
+ auto positionAttrs = positionArrayAttr.getValue();
+ if (positionAttrs.size() > 1) {
+ auto oneDVectorType = reducedVectorTypeBack(vectorType);
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
+ }
+
+    // Remaining extraction of the element from the 1-D LLVM vector.
+ auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+ extracted =
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
+ rewriter.replaceOp(op, extracted);
+
+ return matchSuccess();
+ }
+};
+
+class VectorInsertElementOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorInsertElementOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
+ auto insertEltOp = cast<vector::InsertElementOp>(op);
+ auto vectorType = insertEltOp.getDestVectorType();
+ auto llvmType = lowering.convertType(vectorType);
+
+ // Bail if result type cannot be lowered.
+ if (!llvmType)
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+ op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
+ return matchSuccess();
+ }
+};
+
+class VectorInsertOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorInsertOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::InsertOpOperandAdaptor(operands);
+ auto insertOp = cast<vector::InsertOp>(op);
+ auto sourceType = insertOp.getSourceType();
+ auto destVectorType = insertOp.getDestVectorType();
+ auto llvmResultType = lowering.convertType(destVectorType);
+ auto positionArrayAttr = insertOp.position();
+
+ // Bail if result type cannot be lowered.
+ if (!llvmResultType)
+ return matchFailure();
+
+ // One-shot insertion of a vector into an array (only requires insertvalue).
+ if (sourceType.isa<VectorType>()) {
+ Value inserted = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmResultType, adaptor.dest(), adaptor.source(),
+ positionArrayAttr);
+ rewriter.replaceOp(op, inserted);
+ return matchSuccess();
+ }
+
+ // Potential extraction of 1-D vector from array.
+ auto *context = op->getContext();
+ Value extracted = adaptor.dest();
+ auto positionAttrs = positionArrayAttr.getValue();
+ auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto oneDVectorType = destVectorType;
+ if (positionAttrs.size() > 1) {
+ oneDVectorType = reducedVectorTypeBack(destVectorType);
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, lowering.convertType(oneDVectorType), extracted,
+ nMinusOnePositionAttrs);
+ }
+
+ // Insertion of an element into a 1-D LLVM vector.
+ auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+ Value inserted = rewriter.create<LLVM::InsertElementOp>(
+ loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
+ constant);
+
+ // Potential insertion of resulting 1-D vector into array.
+ if (positionAttrs.size() > 1) {
+ auto nMinusOnePositionAttrs =
+ ArrayAttr::get(positionAttrs.drop_back(), context);
+ inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
+ adaptor.dest(), inserted,
+ nMinusOnePositionAttrs);
+ }
+
+ rewriter.replaceOp(op, inserted);
+ return matchSuccess();
+ }
+};
+
+class VectorOuterProductOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorOuterProductOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
+ auto *ctx = op->getContext();
+ auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+ auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+ auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
+ auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
+ auto llvmArrayOfVectType = lowering.convertType(
+ cast<vector::OuterProductOp>(op).getResult()->getType());
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value a = adaptor.lhs(), b = adaptor.rhs();
+ Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value, 8> lhs, accs;
+ lhs.reserve(rankLHS);
+ accs.reserve(rankLHS);
+ for (unsigned d = 0, e = rankLHS; d < e; ++d) {
+ // shufflevector explicitly requires i32.
+ auto attr = rewriter.getI32IntegerAttr(d);
+ SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
+ auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
+ Value aD = nullptr, accD = nullptr;
+ // 1. Broadcast the element a[d] into vector aD.
+ aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
+ // 2. If acc is present, extract 1-d vector acc[d] into accD.
+ if (acc)
+ accD = rewriter.create<LLVM::ExtractValueOp>(
+ loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
+ // 3. Compute aD outer b (plus accD, if relevant).
+ Value aOuterbD =
+ accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
+ .getResult()
+ : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+ // 4. Insert as value `d` in the descriptor.
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
+ desc, aOuterbD,
+ rewriter.getI64ArrayAttr(d));
+ }
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
+class VectorTypeCastOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorTypeCastOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
+ MemRefType sourceMemRefType =
+ castOp.getOperand()->getType().cast<MemRefType>();
+ MemRefType targetMemRefType =
+ castOp.getResult()->getType().cast<MemRefType>();
+
+    // Only static shape casts are supported at the moment.
+ if (!sourceMemRefType.hasStaticShape() ||
+ !targetMemRefType.hasStaticShape())
+ return matchFailure();
+
+ auto llvmSourceDescriptorTy =
+ operands[0]->getType().dyn_cast<LLVM::LLVMType>();
+ if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+ return matchFailure();
+ MemRefDescriptor sourceMemRef(operands[0]);
+
+ auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
+ .dyn_cast_or_null<LLVM::LLVMType>();
+ if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ return matchFailure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides =
+        getStridesAndOffset(sourceMemRefType, strides, offset);
+    // Inspect the strides only after the computation succeeded, so that we
+    // never read from a potentially empty `strides` vector.
+    if (failed(successStrides))
+      return matchFailure();
+    bool isContiguous = (strides.back() == 1);
+    if (isContiguous) {
+      auto sizes = sourceMemRefType.getShape();
+      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+          isContiguous = false;
+          break;
+        }
+      }
+    }
+    // Only contiguous source memrefs are supported at the moment.
+    if (!isContiguous)
+      return matchFailure();
+
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+
+ // Create descriptor.
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+ Type llvmTargetElementTy = desc.getElementType();
+ // Set allocated ptr.
+ Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
+ allocated =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
+ desc.setAllocatedPtr(rewriter, loc, allocated);
+ // Set aligned ptr.
+ Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
+ ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+ desc.setAlignedPtr(rewriter, loc, ptr);
+ // Fill offset 0.
+ auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
+ auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+ desc.setOffset(rewriter, loc, zero);
+
+ // Fill size and stride descriptors in memref.
+ for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
+ int64_t index = indexedSize.index();
+ auto sizeAttr =
+ rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
+ auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+ desc.setSize(rewriter, loc, index, size);
+ auto strideAttr =
+ rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+ auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+ desc.setStride(rewriter, loc, index, stride);
+ }
+
+ rewriter.replaceOp(op, {desc});
+ return matchSuccess();
+ }
+};
+
+class VectorPrintOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorPrintOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::PrintOp::getOperationName(), context,
+ typeConverter) {}
+
+ // Proof-of-concept lowering implementation that relies on a small
+ // runtime support library, which only needs to provide a few
+ // printing methods (single value for all data types, opening/closing
+ // bracket, comma, newline). The lowering fully unrolls a vector
+ // in terms of these elementary printing operations. The advantage
+ // of this approach is that the library can remain unaware of all
+ // low-level implementation details of vectors while still supporting
+  // output of vectors of any shape and rank. Due to the full unrolling,
+  // however, this approach is less suited for very large vectors.
+ //
+ // TODO(ajcbik): rely solely on libc in future? something else?
+ //
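+  // For reference, a matching runtime support library could be as small as
+  // the following sketch (a hypothetical file; only the function names and
+  // signatures are implied by the declarations emitted below):
+  //
+  //   // runtime_sketch.cpp
+  //   #include <cstdio>
+  //   extern "C" void print_f32(float f)  { std::printf("%g", f); }
+  //   extern "C" void print_f64(double d) { std::printf("%lg", d); }
+  //   extern "C" void print_open()        { std::printf("( "); }
+  //   extern "C" void print_close()       { std::printf(" )"); }
+  //   extern "C" void print_comma()       { std::printf(", "); }
+  //   extern "C" void print_newline()     { std::printf("\n"); }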
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto printOp = cast<vector::PrintOp>(op);
+ auto adaptor = vector::PrintOpOperandAdaptor(operands);
+ Type printType = printOp.getPrintType();
+
+ if (lowering.convertType(printType) == nullptr)
+ return matchFailure();
+
+ // Make sure element type has runtime support (currently just Float/Double).
+ VectorType vectorType = printType.dyn_cast<VectorType>();
+ Type eltType = vectorType ? vectorType.getElementType() : printType;
+ int64_t rank = vectorType ? vectorType.getRank() : 0;
+ Operation *printer;
+ if (eltType.isF32())
+ printer = getPrintFloat(op);
+ else if (eltType.isF64())
+ printer = getPrintDouble(op);
+ else
+ return matchFailure();
+
+ // Unroll vector into elementary print calls.
+ emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
+ emitCall(rewriter, op->getLoc(), getPrintNewline(op));
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+
+private:
+ void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
+ Value value, VectorType vectorType, Operation *printer,
+ int64_t rank) const {
+ Location loc = op->getLoc();
+ if (rank == 0) {
+ emitCall(rewriter, loc, printer, value);
+ return;
+ }
+
+ emitCall(rewriter, loc, getPrintOpen(op));
+ Operation *printComma = getPrintComma(op);
+ int64_t dim = vectorType.getDimSize(0);
+ for (int64_t d = 0; d < dim; ++d) {
+ auto reducedType =
+ rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
+ auto llvmType = lowering.convertType(
+ rank > 1 ? reducedType : vectorType.getElementType());
+ Value nestedVal =
+ extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
+ emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
+ if (d != dim - 1)
+ emitCall(rewriter, loc, printComma);
+ }
+ emitCall(rewriter, loc, getPrintClose(op));
+ }
+
+ // Helper to emit a call.
+ static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
+ Operation *ref, ValueRange params = ValueRange()) {
+ rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
+ rewriter.getSymbolRefAttr(ref), params);
+ }
+
+  // Helper that declares the printer method on first use and simply looks it
+  // up on subsequent uses.
+ static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
+ StringRef name, ArrayRef<LLVM::LLVMType> params) {
+ auto module = op->getParentOfType<ModuleOp>();
+ auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ if (func)
+ return func;
+ OpBuilder moduleBuilder(module.getBodyRegion());
+ return moduleBuilder.create<LLVM::LLVMFuncOp>(
+ op->getLoc(), name,
+ LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
+ params, /*isVarArg=*/false));
+ }
+
+ // Helpers for method names.
+ Operation *getPrintFloat(Operation *op) const {
+ LLVM::LLVMDialect *dialect = lowering.getDialect();
+ return getPrint(op, dialect, "print_f32",
+ LLVM::LLVMType::getFloatTy(dialect));
+ }
+ Operation *getPrintDouble(Operation *op) const {
+ LLVM::LLVMDialect *dialect = lowering.getDialect();
+ return getPrint(op, dialect, "print_f64",
+ LLVM::LLVMType::getDoubleTy(dialect));
+ }
+ Operation *getPrintOpen(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_open", {});
+ }
+ Operation *getPrintClose(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_close", {});
+ }
+ Operation *getPrintComma(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_comma", {});
+ }
+ Operation *getPrintNewline(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_newline", {});
+ }
+};
+
+/// Populate the given list with patterns that convert from Vector to LLVM.
+void mlir::populateVectorToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
+ VectorExtractElementOpConversion, VectorExtractOpConversion,
+ VectorInsertElementOpConversion, VectorInsertOpConversion,
+ VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+ VectorPrintOpConversion>(converter.getDialect()->getContext(),
+ converter);
+}
+
+namespace {
+struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+void LowerVectorToLLVMPass::runOnModule() {
+ // Convert to the LLVM IR dialect using the converter defined above.
+ OwningRewritePatternList patterns;
+ LLVMTypeConverter converter(&getContext());
+ populateVectorToLLVMConversionPatterns(converter, patterns);
+ populateStdToLLVMConversionPatterns(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+ if (failed(
+ applyPartialConversion(getModule(), target, patterns, &converter))) {
+ signalPassFailure();
+ }
+}
+
+OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
+ return new LowerVectorToLLVMPass();
+}
+
+static PassRegistration<LowerVectorToLLVMPass>
+ pass("convert-vector-to-llvm",
+ "Lower the operations from the vector dialect into the LLVM dialect");
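+// Example invocation through the standard mlir-opt driver (illustrative; the
+// flag name comes from the registration above):
+//   mlir-opt -convert-vector-to-llvm input.mlir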
diff --git a/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt b/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt
new file mode 100644
index 00000000000..e213bc9bcce
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_llvm_library(MLIRVectorToLoops
+ ConvertVectorToLoops.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLoops
+)
+set(LIBS
+ MLIRLLVMIR
+ MLIRTransforms
+ LLVMCore
+ LLVMSupport
+ )
+
+add_dependencies(MLIRVectorToLoops ${LIBS})
+target_link_libraries(MLIRVectorToLoops ${LIBS})
diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
new file mode 100644
index 00000000000..3ed031b985a
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
@@ -0,0 +1,358 @@
+//===- VectorToLoops.cpp - Conversion from Vector to mix of Loops and Std -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-dependent lowering of vector transfer operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+
+using namespace mlir;
+using vector::TransferReadOp;
+using vector::TransferWriteOp;
+
+/// Analyzes the `transfer` to find an access dimension along the
+/// fastest-varying (innermost) remote MemRef dimension. If such a dimension
+/// with coalescing properties is found, `pivs` and `vectorView` are swapped so
+/// that the invocation of LoopNestBuilder captures it in the innermost loop.
+template <typename TransferOpTy>
+static void coalesceCopy(TransferOpTy transfer,
+ SmallVectorImpl<edsc::ValueHandle *> *pivs,
+ edsc::VectorView *vectorView) {
+  // Rank of the remote memory access; coalescing behavior occurs on the
+  // innermost memory dimension.
+ auto remoteRank = transfer.getMemRefType().getRank();
+ // Iterate over the results expressions of the permutation map to determine
+ // the loop order for creating pointwise copies between remote and local
+ // memories.
+ int coalescedIdx = -1;
+ auto exprs = transfer.permutation_map().getResults();
+ for (auto en : llvm::enumerate(exprs)) {
+ auto dim = en.value().template dyn_cast<AffineDimExpr>();
+ if (!dim) {
+ continue;
+ }
+ auto memRefDim = dim.getPosition();
+ if (memRefDim == remoteRank - 1) {
+      // memRefDim has coalescing properties; it should be swapped into the
+      // last position.
+ assert(coalescedIdx == -1 && "Unexpected > 1 coalesced indices");
+ coalescedIdx = en.index();
+ }
+ }
+ if (coalescedIdx >= 0) {
+ std::swap(pivs->back(), (*pivs)[coalescedIdx]);
+ vectorView->swapRanges(pivs->size() - 1, coalescedIdx);
+ }
+}
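+// As a concrete illustration (a sketch): for a transfer whose permutation map
+// is (d0, d1, d2) -> (d2, d1) over a rank-3 memref, the result `d2` addresses
+// the innermost memref dimension (remoteRank - 1 == 2), so the induction
+// variable associated with that result is swapped into the last, i.e.
+// innermost, loop position by coalesceCopy above.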
+
+/// Emits remote memory accesses that are clipped to the boundaries of the
+/// MemRef.
+template <typename TransferOpTy>
+static SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer,
+ edsc::MemRefView &view,
+ ArrayRef<edsc::IndexHandle> ivs) {
+ using namespace mlir::edsc;
+ using namespace edsc::op;
+ using edsc::intrinsics::select;
+
+ IndexHandle zero(index_t(0)), one(index_t(1));
+ SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices());
+ SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
+ memRefAccess.size(), edsc::IndexHandle());
+
+  // Indices accessing the remote memory are clipped and their expressions are
+ // returned in clippedScalarAccessExprs.
+ for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
+ ++memRefDim) {
+ // Linear search on a small number of entries.
+ int loopIndex = -1;
+ auto exprs = transfer.permutation_map().getResults();
+ for (auto en : llvm::enumerate(exprs)) {
+ auto expr = en.value();
+ auto dim = expr.template dyn_cast<AffineDimExpr>();
+ // Sanity check.
+ assert(
+ (dim || expr.template cast<AffineConstantExpr>().getValue() == 0) &&
+ "Expected dim or 0 in permutationMap");
+ if (dim && memRefDim == dim.getPosition()) {
+ loopIndex = en.index();
+ break;
+ }
+ }
+
+    // At the moment we cannot distinguish the unrolled dimensions that
+    // implement the "always full" tile abstraction (and hence need no
+    // clipping) from the other ones, so we conservatively clip everything.
+ auto N = view.ub(memRefDim);
+ auto i = memRefAccess[memRefDim];
+ if (loopIndex < 0) {
+ auto N_minus_1 = N - one;
+ auto select_1 = select(i < N, i, N_minus_1);
+ clippedScalarAccessExprs[memRefDim] = select(i < zero, zero, select_1);
+ } else {
+ auto ii = ivs[loopIndex];
+ auto i_plus_ii = i + ii;
+ auto N_minus_1 = N - one;
+ auto select_1 = select(i_plus_ii < N, i_plus_ii, N_minus_1);
+ clippedScalarAccessExprs[memRefDim] =
+ select(i_plus_ii < zero, zero, select_1);
+ }
+ }
+
+ return clippedScalarAccessExprs;
+}
+
+namespace {
+
+using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
+
+/// Implements lowering of TransferReadOp and TransferWriteOp to a
+/// proper abstraction for the hardware.
+///
+/// For now, we only emit a simple loop nest that performs clipped pointwise
+/// copies from a remote to a locally allocated memory.
+///
+/// Consider the case:
+///
+/// ```mlir
+/// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
+/// // vector<32x256xf32> and pad with %f0 to handle the boundary case:
+/// %f0 = constant 0.0f : f32
+/// loop.for %i0 = 0 to %0 {
+/// loop.for %i1 = 0 to %1 step %c256 {
+/// loop.for %i2 = 0 to %2 step %c32 {
+/// %v = vector.transfer_read %A[%i0, %i1, %i2], %f0
+/// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
+/// memref<?x?x?xf32>, vector<32x256xf32>
+/// }}}
+/// ```
+///
+/// The rewriters construct loops and indices that access MemRef A in a pattern
+/// resembling the following (while guaranteeing an always full-tile
+/// abstraction):
+///
+/// ```mlir
+/// loop.for %d2 = 0 to %c256 {
+/// loop.for %d1 = 0 to %c32 {
+/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
+/// %tmp[%d2, %d1] = %s
+/// }
+/// }
+/// ```
+///
+/// In the current state, only a clipping transfer is implemented by `clip`,
+/// which creates individual indexing expressions of the form:
+///
+/// ```mlir-dsc
+/// auto condMax = i + ii < N;
+/// auto max = select(condMax, i + ii, N - one)
+/// auto cond = i + ii < zero;
+/// select(cond, zero, max);
+/// ```
+///
+/// In the future, clipping should not be the only way and instead we should
+/// load vectors + mask them. Similarly on the write side, load/mask/store for
+/// implementing RMW behavior.
+///
+/// Lowers TransferOp into a combination of:
+/// 1. local memory allocation;
+/// 2. perfect loop nest over:
+///      a. scalar loads/stores from the local buffer (viewed as a scalar memref);
+///      b. scalar stores/loads to the original memref (with clipping).
+/// 3. vector_load/store;
+/// 4. local memory deallocation.
+/// Minor variations occur depending on whether a TransferReadOp or
+/// a TransferWriteOp is rewritten.
+template <typename TransferOpTy>
+struct VectorTransferRewriter : public RewritePattern {
+ explicit VectorTransferRewriter(MLIRContext *context)
+ : RewritePattern(TransferOpTy::getOperationName(), 1, context) {}
+
+ /// Used for staging the transfer in a local scalar buffer.
+ MemRefType tmpMemRefType(TransferOpTy transfer) const {
+ auto vectorType = transfer.getVectorType();
+ return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
+ {}, 0);
+ }
+
+ /// Performs the rewrite.
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Lowers TransferReadOp into a combination of:
+/// 1. local memory allocation;
+/// 2. perfect loop nest over:
+///      a. scalar load from the original memref (with clipping);
+///      b. scalar store to the local buffer (viewed as a scalar memref).
+/// 3. vector_load from local buffer (viewed as a memref<1 x vector>);
+/// 4. local memory deallocation.
+///
+/// Lowers the data transfer part of a TransferReadOp while ensuring no
+/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
+/// clipping. This means that a given value in memory can be read multiple
+/// times and concurrently.
+///
+/// Important notes about clipping and "full-tiles only" abstraction:
+/// =================================================================
+/// When using clipping for dealing with boundary conditions, the same edge
+/// value will appear multiple times (a.k.a. edge padding). This is fine if the
+/// subsequent vector operations are all data-parallel but **is generally
+/// incorrect** in the presence of reductions or extract operations.
+///
+/// More generally, clipping is a scalar abstraction that is expected to work
+/// fine as a baseline for CPUs and GPUs but not for vector_load and DMAs.
+/// To deal with real vector_load and DMAs, a "padded allocation + view"
+/// abstraction with the ability to read out-of-memref-bounds (but still within
+/// the allocated region) is necessary.
+///
+/// Whether using scalar loops or vector_load/DMAs to perform the transfer,
+/// junk values will be materialized in the vectors and generally need to be
+/// filtered out and replaced by the "neutral element". This neutral element is
+/// op-dependent so, in the future, we expect to create a vector filter and
+/// apply it to a splatted constant vector with the proper neutral element at
+/// each ssa-use. This filtering is not necessary for pure data-parallel
+/// operations.
+///
+/// In the case of vector_store/DMAs, Read-Modify-Write will be required, which
+/// also has concurrency implications. Note that by using clipped scalar stores
+/// in the presence of data-parallel-only operations, we generate code that
+/// writes the same value multiple times to the edge locations.
+///
+/// TODO(ntv): implement alternatives to clipping.
+/// TODO(ntv): support non-data-parallel operations.
+
+/// Performs the rewrite.
+template <>
+PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ using namespace mlir::edsc;
+ using namespace mlir::edsc::op;
+ using namespace mlir::edsc::intrinsics;
+ using IndexedValue =
+ TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+ TransferReadOp transfer = cast<TransferReadOp>(op);
+
+ // 1. Setup all the captures.
+ ScopedContext scope(rewriter, transfer.getLoc());
+ IndexedValue remote(transfer.memref());
+ MemRefView view(transfer.memref());
+ VectorView vectorView(transfer.vector());
+ SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+ SmallVector<ValueHandle *, 8> pivs =
+ makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+ coalesceCopy(transfer, &pivs, &vectorView);
+
+ auto lbs = vectorView.getLbs();
+ auto ubs = vectorView.getUbs();
+ SmallVector<ValueHandle, 8> steps;
+ steps.reserve(vectorView.getSteps().size());
+ for (auto step : vectorView.getSteps())
+ steps.push_back(constant_index(step));
+
+ // 2. Emit alloc-copy-load-dealloc.
+ ValueHandle tmp = alloc(tmpMemRefType(transfer));
+ IndexedValue local(tmp);
+ ValueHandle vec = vector_type_cast(tmp);
+ LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
+ // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
+ local(ivs) = remote(clip(transfer, view, ivs));
+ });
+ ValueHandle vectorValue = std_load(vec);
+ (dealloc(tmp)); // vexing parse
+
+ // 3. Propagate.
+ rewriter.replaceOp(op, vectorValue.getValue());
+ return matchSuccess();
+}
+
+/// Lowers TransferWriteOp into a combination of:
+/// 1. local memory allocation;
+/// 2. vector_store to local buffer (viewed as a memref<1 x vector>);
+/// 3. perfect loop nest over:
+/// a. scalar load from local buffers (viewed as a scalar memref);
+///      b. scalar store to the original memref (with clipping).
+/// 4. local memory deallocation.
+///
+/// More specifically, lowers the data transfer part while ensuring no
+/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
+/// clipping. This means that a given value in memory can be written to multiple
+/// times and concurrently.
+///
+/// See `Important notes about clipping and "full-tiles only" abstraction` in
+/// the description of the TransferReadOp lowering above.
+///
+/// TODO(ntv): implement alternatives to clipping.
+/// TODO(ntv): support non-data-parallel operations.
+template <>
+PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ using namespace mlir::edsc;
+ using namespace mlir::edsc::op;
+ using namespace mlir::edsc::intrinsics;
+ using IndexedValue =
+ TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+ TransferWriteOp transfer = cast<TransferWriteOp>(op);
+
+ // 1. Setup all the captures.
+ ScopedContext scope(rewriter, transfer.getLoc());
+ IndexedValue remote(transfer.memref());
+ MemRefView view(transfer.memref());
+ ValueHandle vectorValue(transfer.vector());
+ VectorView vectorView(transfer.vector());
+ SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+ SmallVector<ValueHandle *, 8> pivs =
+ makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+ coalesceCopy(transfer, &pivs, &vectorView);
+
+ auto lbs = vectorView.getLbs();
+ auto ubs = vectorView.getUbs();
+ SmallVector<ValueHandle, 8> steps;
+ steps.reserve(vectorView.getSteps().size());
+ for (auto step : vectorView.getSteps())
+ steps.push_back(constant_index(step));
+
+ // 2. Emit alloc-store-copy-dealloc.
+ ValueHandle tmp = alloc(tmpMemRefType(transfer));
+ IndexedValue local(tmp);
+ ValueHandle vec = vector_type_cast(tmp);
+ std_store(vectorValue, vec);
+ LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
+ // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
+ remote(clip(transfer, view, ivs)) = local(ivs);
+ });
+ (dealloc(tmp)); // vexing parse...
+
+ rewriter.eraseOp(op);
+ return matchSuccess();
+}
+
+} // namespace
+
+void mlir::populateVectorToAffineLoopsConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
+ VectorTransferRewriter<vector::TransferWriteOp>>(context);
+}
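+// A minimal sketch of wiring these patterns into a pass, mirroring the
+// greedy-rewrite driver used by the SPIR-V legalization pass earlier in this
+// change (the pass name here is a hypothetical placeholder):
+//
+//   namespace {
+//   struct LowerVectorTransfersPass
+//       : public OperationPass<LowerVectorTransfersPass> {
+//     void runOnOperation() override {
+//       OwningRewritePatternList patterns;
+//       populateVectorToAffineLoopsConversionPatterns(&getContext(), patterns);
+//       applyPatternsGreedily(getOperation()->getRegions(), patterns);
+//     }
+//   };
+//   } // namespace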