Diffstat (limited to 'mlir/lib/Conversion')
34 files changed, 8375 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp new file mode 100644 index 00000000000..e9a9ca82f51 --- /dev/null +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -0,0 +1,550 @@ +//===- AffineToStandard.cpp - Lower affine constructs to primitives -------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file lowers affine constructs (If and For statements, AffineApply +// operations) within a function into their standard If and For equivalent ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/Functional.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +// Visit affine expressions recursively and build the sequence of operations +// that correspond to it. Visitation functions return an Value of the +// expression subtree they visited or `nullptr` on error. +class AffineApplyExpander + : public AffineExprVisitor<AffineApplyExpander, Value> { +public: + // This internal class expects arguments to be non-null, checks must be + // performed at the call site. + AffineApplyExpander(OpBuilder &builder, ArrayRef<Value> dimValues, + ArrayRef<Value> symbolValues, Location loc) + : builder(builder), dimValues(dimValues), symbolValues(symbolValues), + loc(loc) {} + + template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) { + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + if (!lhs || !rhs) + return nullptr; + auto op = builder.create<OpTy>(loc, lhs, rhs); + return op.getResult(); + } + + Value visitAddExpr(AffineBinaryOpExpr expr) { + return buildBinaryExpr<AddIOp>(expr); + } + + Value visitMulExpr(AffineBinaryOpExpr expr) { + return buildBinaryExpr<MulIOp>(expr); + } + + // Euclidean modulo operation: negative RHS is not allowed. + // Remainder of the euclidean integer division is always non-negative. + // + // Implemented as + // + // a mod b = + // let remainder = srem a, b; + // negative = a < 0 in + // select negative, remainder + b, remainder. 
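To make the select-based recipe above concrete, here is a minimal standalone C++ sketch of the same srem/compare/select sequence on plain 64-bit integers. The helper name euclideanMod is illustrative only; the actual lowering that follows builds SignedRemIOp/CmpIOp/SelectOp on index values through the rewriter.

#include <cassert>
#include <cstdint>

// Mirrors the lowering: remainder = srem(a, b); if it is negative, add b.
int64_t euclideanMod(int64_t a, int64_t b) {
  assert(b > 0 && "lowering only supports positive constant divisors");
  int64_t remainder = a % b;                 // srem
  bool isRemainderNegative = remainder < 0;  // cmpi "slt" remainder, 0
  return isRemainderNegative ? remainder + b // select
                             : remainder;
}

int main() {
  assert(euclideanMod(7, 3) == 1);
  assert(euclideanMod(-7, 3) == 2); // C++ % yields -1; affine mod yields 2
  assert(euclideanMod(-9, 3) == 0);
  return 0;
}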
+ Value visitModExpr(AffineBinaryOpExpr expr) { + auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); + if (!rhsConst) { + emitError( + loc, + "semi-affine expressions (modulo by non-const) are not supported"); + return nullptr; + } + if (rhsConst.getValue() <= 0) { + emitError(loc, "modulo by non-positive value is not supported"); + return nullptr; + } + + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + assert(lhs && rhs && "unexpected affine expr lowering failure"); + + Value remainder = builder.create<SignedRemIOp>(loc, lhs, rhs); + Value zeroCst = builder.create<ConstantIndexOp>(loc, 0); + Value isRemainderNegative = + builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst); + Value correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs); + Value result = builder.create<SelectOp>(loc, isRemainderNegative, + correctedRemainder, remainder); + return result; + } + + // Floor division operation (rounds towards negative infinity). + // + // For positive divisors, it can be implemented without branching and with a + // single division operation as + // + // a floordiv b = + // let negative = a < 0 in + // let absolute = negative ? -a - 1 : a in + // let quotient = absolute / b in + // negative ? -quotient - 1 : quotient + Value visitFloorDivExpr(AffineBinaryOpExpr expr) { + auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); + if (!rhsConst) { + emitError( + loc, + "semi-affine expressions (division by non-const) are not supported"); + return nullptr; + } + if (rhsConst.getValue() <= 0) { + emitError(loc, "division by non-positive value is not supported"); + return nullptr; + } + + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + assert(lhs && rhs && "unexpected affine expr lowering failure"); + + Value zeroCst = builder.create<ConstantIndexOp>(loc, 0); + Value noneCst = builder.create<ConstantIndexOp>(loc, -1); + Value negative = + builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs); + Value dividend = + builder.create<SelectOp>(loc, negative, negatedDecremented, lhs); + Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs); + Value correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient); + Value result = + builder.create<SelectOp>(loc, negative, correctedQuotient, quotient); + return result; + } + + // Ceiling division operation (rounds towards positive infinity). + // + // For positive divisors, it can be implemented without branching and with a + // single division operation as + // + // a ceildiv b = + // let negative = a <= 0 in + // let absolute = negative ? -a : a - 1 in + // let quotient = absolute / b in + // negative ? 
-quotient : quotient + 1 + Value visitCeilDivExpr(AffineBinaryOpExpr expr) { + auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); + if (!rhsConst) { + emitError(loc) << "semi-affine expressions (division by non-const) are " + "not supported"; + return nullptr; + } + if (rhsConst.getValue() <= 0) { + emitError(loc, "division by non-positive value is not supported"); + return nullptr; + } + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + assert(lhs && rhs && "unexpected affine expr lowering failure"); + + Value zeroCst = builder.create<ConstantIndexOp>(loc, 0); + Value oneCst = builder.create<ConstantIndexOp>(loc, 1); + Value nonPositive = + builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst); + Value negated = builder.create<SubIOp>(loc, zeroCst, lhs); + Value decremented = builder.create<SubIOp>(loc, lhs, oneCst); + Value dividend = + builder.create<SelectOp>(loc, nonPositive, negated, decremented); + Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs); + Value negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient); + Value incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst); + Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient, + incrementedQuotient); + return result; + } + + Value visitConstantExpr(AffineConstantExpr expr) { + auto valueAttr = + builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); + auto op = + builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr); + return op.getResult(); + } + + Value visitDimExpr(AffineDimExpr expr) { + assert(expr.getPosition() < dimValues.size() && + "affine dim position out of range"); + return dimValues[expr.getPosition()]; + } + + Value visitSymbolExpr(AffineSymbolExpr expr) { + assert(expr.getPosition() < symbolValues.size() && + "symbol dim position out of range"); + return symbolValues[expr.getPosition()]; + } + +private: + OpBuilder &builder; + ArrayRef<Value> dimValues; + ArrayRef<Value> symbolValues; + + Location loc; +}; +} // namespace + +// Create a sequence of operations that implement the `expr` applied to the +// given dimension and symbol values. +mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, + AffineExpr expr, ArrayRef<Value> dimValues, + ArrayRef<Value> symbolValues) { + return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); +} + +// Create a sequence of operations that implement the `affineMap` applied to +// the given `operands` (as it it were an AffineApplyOp). +Optional<SmallVector<Value, 8>> static expandAffineMap( + OpBuilder &builder, Location loc, AffineMap affineMap, + ArrayRef<Value> operands) { + auto numDims = affineMap.getNumDims(); + auto expanded = functional::map( + [numDims, &builder, loc, operands](AffineExpr expr) { + return expandAffineExpr(builder, loc, expr, + operands.take_front(numDims), + operands.drop_front(numDims)); + }, + affineMap.getResults()); + if (llvm::all_of(expanded, [](Value v) { return v; })) + return expanded; + return None; +} + +// Given a range of values, emit the code that reduces them with "min" or "max" +// depending on the provided comparison predicate. The predicate defines which +// comparison to perform, "lt" for "min", "gt" for "max" and is used for the +// `cmpi` operation followed by the `select` operation: +// +// %cond = cmpi "predicate" %v0, %v1 +// %result = select %cond, %v0, %v1 +// +// Multiple values are scanned in a linear sequence. 
This creates a data +// dependences that wouldn't exist in a tree reduction, but is easier to +// recognize as a reduction by the subsequent passes. +static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, + ArrayRef<Value> values, + OpBuilder &builder) { + assert(!llvm::empty(values) && "empty min/max chain"); + + auto valueIt = values.begin(); + Value value = *valueIt++; + for (; valueIt != values.end(); ++valueIt) { + auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt); + value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt); + } + + return value; +} + +// Emit instructions that correspond to the affine map in the lower bound +// applied to the respective operands, and compute the maximum value across +// the results. +Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { + SmallVector<Value, 8> boundOperands(op.getLowerBoundOperands()); + auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(), + boundOperands); + if (!lbValues) + return nullptr; + return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues, + builder); +} + +// Emit instructions that correspond to the affine map in the upper bound +// applied to the respective operands, and compute the minimum value across +// the results. +Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { + SmallVector<Value, 8> boundOperands(op.getUpperBoundOperands()); + auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(), + boundOperands); + if (!ubValues) + return nullptr; + return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues, + builder); +} + +namespace { +// Affine terminators are removed. +class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> { +public: + using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineTerminatorOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<loop::TerminatorOp>(op); + return matchSuccess(); + } +}; + +class AffineForLowering : public OpRewritePattern<AffineForOp> { +public: + using OpRewritePattern<AffineForOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineForOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value lowerBound = lowerAffineLowerBound(op, rewriter); + Value upperBound = lowerAffineUpperBound(op, rewriter); + Value step = rewriter.create<ConstantIndexOp>(loc, op.getStep()); + auto f = rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step); + f.region().getBlocks().clear(); + rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +class AffineIfLowering : public OpRewritePattern<AffineIfOp> { +public: + using OpRewritePattern<AffineIfOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineIfOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // Now we just have to handle the condition logic. + auto integerSet = op.getIntegerSet(); + Value zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0); + SmallVector<Value, 8> operands(op.getOperands()); + auto operandsRef = llvm::makeArrayRef(operands); + + // Calculate cond as a conjunction without short-circuiting. 
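As a plain-integer analogue of the loop that follows, the sketch below evaluates integer-set membership as the eager AND of every constraint check, with an empty constraint list degenerating to true. The names Constraint, integerSetCondition, and inSet are illustrative; the real code emits cmpi/and ops through the rewriter rather than computing booleans.

#include <cassert>
#include <cstdint>
#include <vector>

// One expanded constraint value together with its kind: an equality
// constraint means "value == 0", an inequality means "value >= 0".
struct Constraint {
  int64_t value;
  bool isEquality;
};

// Eagerly AND all constraint checks, mirroring the cmpi/and chain built
// below; every check is computed, and no constraints means "true".
bool integerSetCondition(const std::vector<Constraint> &constraints) {
  bool cond = true;
  for (const Constraint &c : constraints)
    cond = cond & (c.isEquality ? c.value == 0 : c.value >= 0);
  return cond;
}

int main() {
  // The set (d0 >= 0, d0 <= 10) is stored as (d0 >= 0, -d0 + 10 >= 0).
  auto inSet = [](int64_t d0) {
    return integerSetCondition({{d0, false}, {-d0 + 10, false}});
  };
  assert(inSet(4));
  assert(!inSet(12));
  assert(integerSetCondition({}) && "empty set is trivially true");
  return 0;
}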
+ Value cond = nullptr; + for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) { + AffineExpr constraintExpr = integerSet.getConstraint(i); + bool isEquality = integerSet.isEq(i); + + // Build and apply an affine expression + auto numDims = integerSet.getNumDims(); + Value affResult = expandAffineExpr(rewriter, loc, constraintExpr, + operandsRef.take_front(numDims), + operandsRef.drop_front(numDims)); + if (!affResult) + return matchFailure(); + auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; + Value cmpVal = + rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant); + cond = + cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal; + } + cond = cond ? cond + : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1); + + bool hasElseRegion = !op.elseRegion().empty(); + auto ifOp = rewriter.create<loop::IfOp>(loc, cond, hasElseRegion); + rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back()); + ifOp.thenRegion().back().erase(); + if (hasElseRegion) { + rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back()); + ifOp.elseRegion().back().erase(); + } + + // Ok, we're done! + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +// Convert an "affine.apply" operation into a sequence of arithmetic +// operations using the StandardOps dialect. +class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> { +public: + using OpRewritePattern<AffineApplyOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineApplyOp op, + PatternRewriter &rewriter) const override { + auto maybeExpandedMap = + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), + llvm::to_vector<8>(op.getOperands())); + if (!maybeExpandedMap) + return matchFailure(); + rewriter.replaceOp(op, *maybeExpandedMap); + return matchSuccess(); + } +}; + +// Apply the affine map from an 'affine.load' operation to its operands, and +// feed the results to a newly created 'std.load' operation (which replaces the +// original 'affine.load'). +class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> { +public: + using OpRewritePattern<AffineLoadOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineLoadOp op, + PatternRewriter &rewriter) const override { + // Expand affine map from 'affineLoadOp'. + SmallVector<Value, 8> indices(op.getMapOperands()); + auto resultOperands = + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); + if (!resultOperands) + return matchFailure(); + + // Build std.load memref[expandedMap.results]. + rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands); + return matchSuccess(); + } +}; + +// Apply the affine map from an 'affine.prefetch' operation to its operands, and +// feed the results to a newly created 'std.prefetch' operation (which replaces +// the original 'affine.prefetch'). +class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> { +public: + using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffinePrefetchOp op, + PatternRewriter &rewriter) const override { + // Expand affine map from 'affinePrefetchOp'. + SmallVector<Value, 8> indices(op.getMapOperands()); + auto resultOperands = + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); + if (!resultOperands) + return matchFailure(); + + // Build std.prefetch memref[expandedMap.results]. 
+ rewriter.replaceOpWithNewOp<PrefetchOp>( + op, op.memref(), *resultOperands, op.isWrite(), + op.localityHint().getZExtValue(), op.isDataCache()); + return matchSuccess(); + } +}; + +// Apply the affine map from an 'affine.store' operation to its operands, and +// feed the results to a newly created 'std.store' operation (which replaces the +// original 'affine.store'). +class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> { +public: + using OpRewritePattern<AffineStoreOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineStoreOp op, + PatternRewriter &rewriter) const override { + // Expand affine map from 'affineStoreOp'. + SmallVector<Value, 8> indices(op.getMapOperands()); + auto maybeExpandedMap = + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); + if (!maybeExpandedMap) + return matchFailure(); + + // Build std.store valueToStore, memref[expandedMap.results]. + rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(), + op.getMemRef(), *maybeExpandedMap); + return matchSuccess(); + } +}; + +// Apply the affine maps from an 'affine.dma_start' operation to each of their +// respective map operands, and feed the results to a newly created +// 'std.dma_start' operation (which replaces the original 'affine.dma_start'). +class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> { +public: + using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineDmaStartOp op, + PatternRewriter &rewriter) const override { + SmallVector<Value, 8> operands(op.getOperands()); + auto operandsRef = llvm::makeArrayRef(operands); + + // Expand affine map for DMA source memref. + auto maybeExpandedSrcMap = expandAffineMap( + rewriter, op.getLoc(), op.getSrcMap(), + operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1)); + if (!maybeExpandedSrcMap) + return matchFailure(); + // Expand affine map for DMA destination memref. + auto maybeExpandedDstMap = expandAffineMap( + rewriter, op.getLoc(), op.getDstMap(), + operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1)); + if (!maybeExpandedDstMap) + return matchFailure(); + // Expand affine map for DMA tag memref. + auto maybeExpandedTagMap = expandAffineMap( + rewriter, op.getLoc(), op.getTagMap(), + operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1)); + if (!maybeExpandedTagMap) + return matchFailure(); + + // Build std.dma_start operation with affine map results. + rewriter.replaceOpWithNewOp<DmaStartOp>( + op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(), + *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(), + *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride()); + return matchSuccess(); + } +}; + +// Apply the affine map from an 'affine.dma_wait' operation tag memref, +// and feed the results to a newly created 'std.dma_wait' operation (which +// replaces the original 'affine.dma_wait'). +class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> { +public: + using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineDmaWaitOp op, + PatternRewriter &rewriter) const override { + // Expand affine map for DMA tag memref. + SmallVector<Value, 8> indices(op.getTagIndices()); + auto maybeExpandedTagMap = + expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); + if (!maybeExpandedTagMap) + return matchFailure(); + + // Build std.dma_wait operation with affine map results. 
+ rewriter.replaceOpWithNewOp<DmaWaitOp>( + op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements()); + return matchSuccess(); + } +}; + +} // end namespace + +void mlir::populateAffineToStdConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert< + AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering, + AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering, + AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx); +} + +namespace { +class LowerAffinePass : public FunctionPass<LowerAffinePass> { + void runOnFunction() override { + OwningRewritePatternList patterns; + populateAffineToStdConversionPatterns(patterns, &getContext()); + ConversionTarget target(getContext()); + target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>(); + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); + } +}; +} // namespace + +/// Lowers If and For operations within a function into their lower level CFG +/// equivalent blocks. +std::unique_ptr<OpPassBase<FuncOp>> mlir::createLowerAffinePass() { + return std::make_unique<LowerAffinePass>(); +} + +static PassRegistration<LowerAffinePass> + pass("lower-affine", + "Lower If, For, AffineApply operations to primitive equivalents"); diff --git a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt new file mode 100644 index 00000000000..33f7db7abc4 --- /dev/null +++ b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt @@ -0,0 +1,24 @@ +add_llvm_library(MLIRAffineToStandard + AffineToStandard.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AffineToStandard +) +add_dependencies( + MLIRAffineToStandard + + MLIRAffineOps + MLIRStandardOps + MLIRIR + LLVMCore + LLVMSupport +) +target_link_libraries( + MLIRAffineToStandard + + MLIRAffineOps + MLIRStandardOps + MLIRIR + LLVMCore + LLVMSupport +) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt new file mode 100644 index 00000000000..c791d214d30 --- /dev/null +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -0,0 +1,12 @@ +add_subdirectory(AffineToStandard) +add_subdirectory(GPUToCUDA) +add_subdirectory(GPUToNVVM) +add_subdirectory(GPUToROCDL) +add_subdirectory(GPUToSPIRV) +add_subdirectory(LinalgToLLVM) +add_subdirectory(LoopsToGPU) +add_subdirectory(LoopToStandard) +add_subdirectory(StandardToLLVM) +add_subdirectory(StandardToSPIRV) +add_subdirectory(VectorToLLVM) +add_subdirectory(VectorToLoops) diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h new file mode 100644 index 00000000000..63bc15173be --- /dev/null +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -0,0 +1,85 @@ +//===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ +#define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +#include "llvm/ADT/StringSwitch.h" + +namespace mlir { + +// Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension +// that Op operates on. Op is assumed to return an `std.index` value and +// XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on +// `indexBitwidth`, sign-extend or truncate the resulting value to match the +// bitwidth expected by the consumers of the value. +template <typename Op, typename XOp, typename YOp, typename ZOp> +struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering { +private: + enum dimension { X = 0, Y = 1, Z = 2, invalid }; + unsigned indexBitwidth; + + static dimension dimensionToIndex(Op op) { + return llvm::StringSwitch<dimension>(op.dimension()) + .Case("x", X) + .Case("y", Y) + .Case("z", Z) + .Default(invalid); + } + + static unsigned getIndexBitWidth(LLVMTypeConverter &type_converter) { + auto dialect = type_converter.getDialect(); + return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits(); + } + +public: + explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_) + : LLVMOpLowering(Op::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), + indexBitwidth(getIndexBitWidth(lowering_)) {} + + // Convert the kernel arguments to an LLVM type, preserve the rest. + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto dialect = lowering.getDialect(); + Value newOp; + switch (dimensionToIndex(cast<Op>(op))) { + case X: + newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect)); + break; + case Y: + newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect)); + break; + case Z: + newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect)); + break; + default: + return matchFailure(); + } + + if (indexBitwidth > 32) { + newOp = rewriter.create<LLVM::SExtOp>( + loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp); + } else if (indexBitwidth < 32) { + newOp = rewriter.create<LLVM::TruncOp>( + loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp); + } + + rewriter.replaceOp(op, {newOp}); + return matchSuccess(); + } +}; + +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h new file mode 100644 index 00000000000..b75c1bf2d7b --- /dev/null +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -0,0 +1,100 @@ +//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ +#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" + +namespace mlir { + +/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` +/// depending on the element type that Op operates upon. The function +/// declaration is added in case it was not added before. +/// +/// Example with NVVM: +/// %exp_f32 = std.exp %arg_f32 : f32 +/// +/// will be transformed into +/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float +template <typename SourceOp> +struct OpToFuncCallLowering : public LLVMOpLowering { +public: + explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, + StringRef f64Func) + : LLVMOpLowering(SourceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), + f32Func(f32Func), f64Func(f64Func) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + using LLVM::LLVMFuncOp; + using LLVM::LLVMType; + + static_assert( + std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, + "expected single result op"); + + LLVMType resultType = lowering.convertType(op->getResult(0)->getType()) + .template cast<LLVM::LLVMType>(); + LLVMType funcType = getFunctionType(resultType, operands); + StringRef funcName = getFunctionName(resultType); + if (funcName.empty()) + return matchFailure(); + + LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); + auto callOp = rewriter.create<LLVM::CallOp>( + op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands); + rewriter.replaceOp(op, {callOp.getResult(0)}); + return matchSuccess(); + } + +private: + LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, + ArrayRef<Value> operands) const { + using LLVM::LLVMType; + SmallVector<LLVMType, 1> operandTypes; + for (Value operand : operands) { + operandTypes.push_back(operand->getType().cast<LLVMType>()); + } + return LLVMType::getFunctionTy(resultType, operandTypes, + /*isVarArg=*/false); + } + + StringRef getFunctionName(LLVM::LLVMType type) const { + if (type.isFloatTy()) + return f32Func; + if (type.isDoubleTy()) + return f64Func; + return ""; + } + + LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, + LLVM::LLVMType funcType, + Operation *op) const { + using LLVM::LLVMFuncOp; + + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName); + if (funcOp) + return cast<LLVMFuncOp>(*funcOp); + + mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>()); + return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); + } + + const std::string f32Func; + const std::string f64Func; +}; + +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt new file mode 100644 index 00000000000..4eddb787493 --- /dev/null +++ b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt @@ -0,0 +1,16 @@ +if(MLIR_CUDA_CONVERSIONS_ENABLED) + llvm_map_components_to_libnames(nvptx "NVPTX") + + add_llvm_library(MLIRGPUtoCUDATransforms + ConvertKernelFuncToCubin.cpp + 
ConvertLaunchFuncToCudaCalls.cpp + ) + target_link_libraries(MLIRGPUtoCUDATransforms + MLIRGPU + MLIRLLVMIR + MLIRNVVMIR + MLIRPass + MLIRTargetNVVMIR + ${nvptx} + ) +endif() diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp new file mode 100644 index 00000000000..66a2e66f99a --- /dev/null +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -0,0 +1,167 @@ +//===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert gpu kernel functions into a +// corresponding binary blob that can be executed on a CUDA GPU. Currently +// only translates the function itself but no dependencies. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h" + +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/NVVMIR.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/Twine.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" + +using namespace mlir; + +namespace { +// TODO(herhut): Move to shared location. +static constexpr const char *kCubinAnnotation = "nvvm.cubin"; + +/// A pass converting tagged kernel modules to cubin blobs. +/// +/// If tagged as a kernel module, each contained function is translated to NVVM +/// IR and further to PTX. A user provided CubinGenerator compiles the PTX to +/// GPU binary code, which is then attached as an attribute to the function. The +/// function body is erased. +class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> { +public: + GpuKernelToCubinPass( + CubinGenerator cubinGenerator = compilePtxToCubinForTesting) + : cubinGenerator(cubinGenerator) {} + + void runOnModule() override { + ModuleOp module = getModule(); + if (!module.getAttrOfType<UnitAttr>( + gpu::GPUDialect::getKernelModuleAttrName()) || + !module.getName()) + return; + + // Make sure the NVPTX target is initialized. + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + + auto llvmModule = translateModuleToNVVMIR(module); + if (!llvmModule) + return signalPassFailure(); + + // Translate the module to CUBIN and attach the result as attribute to the + // module. + if (auto cubinAttr = translateGpuModuleToCubinAnnotation( + *llvmModule, module.getLoc(), *module.getName())) + module.setAttr(kCubinAnnotation, cubinAttr); + else + signalPassFailure(); + } + +private: + static OwnedCubin compilePtxToCubinForTesting(const std::string &ptx, + Location, StringRef); + + std::string translateModuleToPtx(llvm::Module &module, + llvm::TargetMachine &target_machine); + + /// Converts llvmModule to cubin using the user-provided generator. 
Location + /// is used for error reporting and name is forwarded to the CUBIN generator + /// to use in its logging mechanisms. + OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Location loc, + StringRef name); + + /// Translates llvmModule to cubin and returns the result as attribute. + StringAttr translateGpuModuleToCubinAnnotation(llvm::Module &llvmModule, + Location loc, StringRef name); + + CubinGenerator cubinGenerator; +}; + +} // anonymous namespace + +std::string GpuKernelToCubinPass::translateModuleToPtx( + llvm::Module &module, llvm::TargetMachine &target_machine) { + std::string ptx; + { + llvm::raw_string_ostream stream(ptx); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager codegen_passes; + target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr, + llvm::CGFT_AssemblyFile); + codegen_passes.run(module); + } + + return ptx; +} + +OwnedCubin +GpuKernelToCubinPass::compilePtxToCubinForTesting(const std::string &ptx, + Location, StringRef) { + const char data[] = "CUBIN"; + return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1); +} + +OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule, + Location loc, + StringRef name) { + std::unique_ptr<llvm::TargetMachine> targetMachine; + { + std::string error; + // TODO(herhut): Make triple configurable. + constexpr const char *cudaTriple = "nvptx64-nvidia-cuda"; + llvm::Triple triple(cudaTriple); + const llvm::Target *target = + llvm::TargetRegistry::lookupTarget("", triple, error); + if (target == nullptr) { + emitError(loc, "cannot initialize target triple"); + return {}; + } + targetMachine.reset( + target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {})); + } + + // Set the data layout of the llvm module to match what the ptx target needs. + llvmModule.setDataLayout(targetMachine->createDataLayout()); + + auto ptx = translateModuleToPtx(llvmModule, *targetMachine); + + return cubinGenerator(ptx, loc, name); +} + +StringAttr GpuKernelToCubinPass::translateGpuModuleToCubinAnnotation( + llvm::Module &llvmModule, Location loc, StringRef name) { + auto cubin = convertModuleToCubin(llvmModule, loc, name); + if (!cubin) + return {}; + return StringAttr::get({cubin->data(), cubin->size()}, loc->getContext()); +} + +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) { + return std::make_unique<GpuKernelToCubinPass>(cubinGenerator); +} + +static PassRegistration<GpuKernelToCubinPass> + pass("test-kernel-to-cubin", + "Convert all kernel functions to CUDA cubin blobs"); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp new file mode 100644 index 00000000000..19dabcdafee --- /dev/null +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -0,0 +1,424 @@ +//===- ConvertLaunchFuncToCudaCalls.cpp - MLIR CUDA lowering passes -------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert gpu.launch_func op into a sequence of +// CUDA runtime calls. 
As the CUDA runtime does not have a stable published ABI, +// this pass uses a slim runtime layer that builds on top of the public API from +// the CUDA headers. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h" + +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +// To avoid name mangling, these are defined in the mini-runtime file. +static constexpr const char *cuModuleLoadName = "mcuModuleLoad"; +static constexpr const char *cuModuleGetFunctionName = "mcuModuleGetFunction"; +static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel"; +static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper"; +static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize"; +static constexpr const char *kMcuMemHostRegister = "mcuMemHostRegister"; + +static constexpr const char *kCubinAnnotation = "nvvm.cubin"; +static constexpr const char *kCubinStorageSuffix = "_cubin_cst"; + +namespace { + +/// A pass to convert gpu.launch_func operations into a sequence of CUDA +/// runtime calls. +/// +/// In essence, a gpu.launch_func operations gets compiled into the following +/// sequence of runtime calls: +/// +/// * mcuModuleLoad -- loads the module given the cubin data +/// * mcuModuleGetFunction -- gets a handle to the actual kernel function +/// * mcuGetStreamHelper -- initializes a new CUDA stream +/// * mcuLaunchKernelName -- launches the kernel on a stream +/// * mcuStreamSynchronize -- waits for operations on the stream to finish +/// +/// Intermediate data structures are allocated on the stack. 
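To show the call order the pass produces from the host's point of view, here is a small C++ sketch under stated assumptions: the mcu* functions are trivial local stand-ins with signatures simplified from the void*-erased declarations emitted by declareCudaFunctions further down, not the real mini-runtime, and the cubin blob and kernel name are placeholders.

#include <cstdint>
#include <cstdio>

// Trivial stand-ins for the CUDA mini-runtime wrappers (real ones live in
// the runtime library and return CUresult-style error codes).
static int32_t mcuModuleLoad(void **module, void *cubin) {
  *module = cubin;
  return 0;
}
static int32_t mcuModuleGetFunction(void **function, void *module,
                                    const char *name) {
  std::printf("looking up kernel %s\n", name);
  *function = module;
  return 0;
}
static void *mcuGetStreamHelper() {
  static int stream;
  return &stream;
}
static int32_t mcuLaunchKernel(void *f, intptr_t gx, intptr_t gy, intptr_t gz,
                               intptr_t bx, intptr_t by, intptr_t bz,
                               int32_t sharedMemBytes, void *stream,
                               void **params, void **extra) {
  return 0;
}
static int32_t mcuStreamSynchronize(void *stream) { return 0; }

int main() {
  char cubin[] = "CUBIN"; // stands in for the embedded cubin constant
  void *module = nullptr, *function = nullptr;
  mcuModuleLoad(&module, cubin);
  mcuModuleGetFunction(&function, module, "kernel");
  void *stream = mcuGetStreamHelper();
  void *params[1] = {nullptr}; // see setupParamsArray below
  mcuLaunchKernel(function, 1, 1, 1, 1, 1, 1, /*sharedMemBytes=*/0, stream,
                  params, /*extra=*/nullptr);
  mcuStreamSynchronize(stream);
  return 0;
}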
+class GpuLaunchFuncToCudaCallsPass + : public ModulePass<GpuLaunchFuncToCudaCallsPass> { +private: + LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } + + llvm::LLVMContext &getLLVMContext() { + return getLLVMDialect()->getLLVMContext(); + } + + void initializeCachedTypes() { + const llvm::Module &module = llvmDialect->getLLVMModule(); + llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); + llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + llvmPointerPointerType = llvmPointerType.getPointerTo(); + llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect); + llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); + llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); + llvmIntPtrType = LLVM::LLVMType::getIntNTy( + llvmDialect, module.getDataLayout().getPointerSizeInBits()); + } + + LLVM::LLVMType getVoidType() { return llvmVoidType; } + + LLVM::LLVMType getPointerType() { return llvmPointerType; } + + LLVM::LLVMType getPointerPointerType() { return llvmPointerPointerType; } + + LLVM::LLVMType getInt8Type() { return llvmInt8Type; } + + LLVM::LLVMType getInt32Type() { return llvmInt32Type; } + + LLVM::LLVMType getInt64Type() { return llvmInt64Type; } + + LLVM::LLVMType getIntPtrType() { + const llvm::Module &module = getLLVMDialect()->getLLVMModule(); + return LLVM::LLVMType::getIntNTy( + getLLVMDialect(), module.getDataLayout().getPointerSizeInBits()); + } + + LLVM::LLVMType getCUResultType() { + // This is declared as an enum in CUDA but helpers use i32. + return getInt32Type(); + } + + // Allocate a void pointer on the stack. + Value allocatePointer(OpBuilder &builder, Location loc) { + auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(), + builder.getI32IntegerAttr(1)); + return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one, + /*alignment=*/0); + } + + void declareCudaFunctions(Location loc); + Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); + Value generateKernelNameConstant(StringRef name, Location loc, + OpBuilder &builder); + void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); + +public: + // Run the dialect converter on the module. + void runOnModule() override { + // Cache the LLVMDialect for the current module. + llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); + // Cache the used LLVM types. + initializeCachedTypes(); + + getModule().walk([this](mlir::gpu::LaunchFuncOp op) { + translateGpuLaunchCalls(op); + }); + + // GPU kernel modules are no longer necessary since we have a global + // constant with the CUBIN data. + for (auto m : llvm::make_early_inc_range(getModule().getOps<ModuleOp>())) + if (m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName())) + m.erase(); + } + +private: + LLVM::LLVMDialect *llvmDialect; + LLVM::LLVMType llvmVoidType; + LLVM::LLVMType llvmPointerType; + LLVM::LLVMType llvmPointerPointerType; + LLVM::LLVMType llvmInt8Type; + LLVM::LLVMType llvmInt32Type; + LLVM::LLVMType llvmInt64Type; + LLVM::LLVMType llvmIntPtrType; +}; + +} // anonymous namespace + +// Adds declarations for the needed helper functions from the CUDA wrapper. +// The types in comments give the actual types expected/returned but the API +// uses void pointers. This is fine as they have the same linkage in C. 
+void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { + ModuleOp module = getModule(); + OpBuilder builder(module.getBody()->getTerminator()); + if (!module.lookupSymbol(cuModuleLoadName)) { + builder.create<LLVM::LLVMFuncOp>( + loc, cuModuleLoadName, + LLVM::LLVMType::getFunctionTy( + getCUResultType(), + { + getPointerPointerType(), /* CUmodule *module */ + getPointerType() /* void *cubin */ + }, + /*isVarArg=*/false)); + } + if (!module.lookupSymbol(cuModuleGetFunctionName)) { + // The helper uses void* instead of CUDA's opaque CUmodule and + // CUfunction. + builder.create<LLVM::LLVMFuncOp>( + loc, cuModuleGetFunctionName, + LLVM::LLVMType::getFunctionTy( + getCUResultType(), + { + getPointerPointerType(), /* void **function */ + getPointerType(), /* void *module */ + getPointerType() /* char *name */ + }, + /*isVarArg=*/false)); + } + if (!module.lookupSymbol(cuLaunchKernelName)) { + // Other than the CUDA api, the wrappers use uintptr_t to match the + // LLVM type if MLIR's index type, which the GPU dialect uses. + // Furthermore, they use void* instead of CUDA's opaque CUfunction and + // CUstream. + builder.create<LLVM::LLVMFuncOp>( + loc, cuLaunchKernelName, + LLVM::LLVMType::getFunctionTy( + getCUResultType(), + { + getPointerType(), /* void* f */ + getIntPtrType(), /* intptr_t gridXDim */ + getIntPtrType(), /* intptr_t gridyDim */ + getIntPtrType(), /* intptr_t gridZDim */ + getIntPtrType(), /* intptr_t blockXDim */ + getIntPtrType(), /* intptr_t blockYDim */ + getIntPtrType(), /* intptr_t blockZDim */ + getInt32Type(), /* unsigned int sharedMemBytes */ + getPointerType(), /* void *hstream */ + getPointerPointerType(), /* void **kernelParams */ + getPointerPointerType() /* void **extra */ + }, + /*isVarArg=*/false)); + } + if (!module.lookupSymbol(cuGetStreamHelperName)) { + // Helper function to get the current CUDA stream. Uses void* instead of + // CUDAs opaque CUstream. + builder.create<LLVM::LLVMFuncOp>( + loc, cuGetStreamHelperName, + LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false)); + } + if (!module.lookupSymbol(cuStreamSynchronizeName)) { + builder.create<LLVM::LLVMFuncOp>( + loc, cuStreamSynchronizeName, + LLVM::LLVMType::getFunctionTy(getCUResultType(), + getPointerType() /* CUstream stream */, + /*isVarArg=*/false)); + } + if (!module.lookupSymbol(kMcuMemHostRegister)) { + builder.create<LLVM::LLVMFuncOp>( + loc, kMcuMemHostRegister, + LLVM::LLVMType::getFunctionTy(getVoidType(), + { + getPointerType(), /* void *ptr */ + getInt64Type() /* int64 sizeBytes*/ + }, + /*isVarArg=*/false)); + } +} + +// Generates a parameters array to be used with a CUDA kernel launch call. The +// arguments are extracted from the launchOp. +// The generated code is essentially as follows: +// +// %array = alloca(numparams * sizeof(void *)) +// for (i : [0, NumKernelOperands)) +// %array[i] = cast<void*>(KernelOperand[i]) +// return %array +Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, + OpBuilder &builder) { + auto numKernelOperands = launchOp.getNumKernelOperands(); + Location loc = launchOp.getLoc(); + auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(), + builder.getI32IntegerAttr(1)); + // Provision twice as much for the `array` to allow up to one level of + // indirection for each argument. 
+ auto arraySize = builder.create<LLVM::ConstantOp>( + loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands)); + auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), + arraySize, /*alignment=*/0); + for (unsigned idx = 0; idx < numKernelOperands; ++idx) { + auto operand = launchOp.getKernelOperand(idx); + auto llvmType = operand->getType().cast<LLVM::LLVMType>(); + Value memLocation = builder.create<LLVM::AllocaOp>( + loc, llvmType.getPointerTo(), one, /*alignment=*/1); + builder.create<LLVM::StoreOp>(loc, operand, memLocation); + auto casted = + builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation); + + // Assume all struct arguments come from MemRef. If this assumption does not + // hold anymore then we `launchOp` to lower from MemRefType and not after + // LLVMConversion has taken place and the MemRef information is lost. + // Extra level of indirection in the `array`: + // the descriptor pointer is registered via @mcuMemHostRegisterPtr + if (llvmType.isStructTy()) { + auto registerFunc = + getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister); + auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo()); + auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(), + ArrayRef<Value>{nullPtr, one}); + auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep); + builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, + builder.getSymbolRefAttr(registerFunc), + ArrayRef<Value>{casted, size}); + Value memLocation = builder.create<LLVM::AllocaOp>( + loc, getPointerPointerType(), one, /*alignment=*/1); + builder.create<LLVM::StoreOp>(loc, casted, memLocation); + casted = + builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation); + } + + auto index = builder.create<LLVM::ConstantOp>( + loc, getInt32Type(), builder.getI32IntegerAttr(idx)); + auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array, + ArrayRef<Value>{index}); + builder.create<LLVM::StoreOp>(loc, casted, gep); + } + return array; +} + +// Generates an LLVM IR dialect global that contains the name of the given +// kernel function as a C string, and returns a pointer to its beginning. +// The code is essentially: +// +// llvm.global constant @kernel_name("function_name\00") +// func(...) { +// %0 = llvm.addressof @kernel_name +// %1 = llvm.constant (0 : index) +// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> +// } +Value GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( + StringRef name, Location loc, OpBuilder &builder) { + // Make sure the trailing zero is included in the constant. + std::vector<char> kernelName(name.begin(), name.end()); + kernelName.push_back('\0'); + + std::string globalName = llvm::formatv("{0}_kernel_name", name); + return LLVM::createGlobalString( + loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), + LLVM::Linkage::Internal, llvmDialect); +} + +// Emits LLVM IR to launch a kernel function. Expects the module that contains +// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute of the +// kernel function in the IR. +// While MLIR has no global constants, also expects a cubin getter function in +// an 'nvvm.cubingetter' attribute. Such function is expected to return a +// pointer to the cubin blob when invoked. 
+// With these given, the generated code in essence is +// +// %0 = call %cubingetter +// %1 = alloca sizeof(void*) +// call %mcuModuleLoad(%2, %1) +// %2 = alloca sizeof(void*) +// %3 = load %1 +// %4 = <see generateKernelNameConstant> +// call %mcuModuleGetFunction(%2, %3, %4) +// %5 = call %mcuGetStreamHelper() +// %6 = load %2 +// %7 = <see setupParamsArray> +// call %mcuLaunchKernel(%6, <launchOp operands 0..5>, 0, %5, %7, nullptr) +// call %mcuStreamSynchronize(%5) +void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( + mlir::gpu::LaunchFuncOp launchOp) { + OpBuilder builder(launchOp); + Location loc = launchOp.getLoc(); + declareCudaFunctions(loc); + + auto zero = builder.create<LLVM::ConstantOp>(loc, getInt32Type(), + builder.getI32IntegerAttr(0)); + // Create an LLVM global with CUBIN extracted from the kernel annotation and + // obtain a pointer to the first byte in it. + auto kernelModule = + getModule().lookupSymbol<ModuleOp>(launchOp.getKernelModuleName()); + assert(kernelModule && "expected a kernel module"); + + auto cubinAttr = kernelModule.getAttrOfType<StringAttr>(kCubinAnnotation); + if (!cubinAttr) { + kernelModule.emitOpError() + << "missing " << kCubinAnnotation << " attribute"; + return signalPassFailure(); + } + + assert(kernelModule.getName() && "expected a named module"); + SmallString<128> nameBuffer(*kernelModule.getName()); + nameBuffer.append(kCubinStorageSuffix); + Value data = LLVM::createGlobalString( + loc, builder, nameBuffer.str(), cubinAttr.getValue(), + LLVM::Linkage::Internal, getLLVMDialect()); + + // Emit the load module call to load the module data. Error checking is done + // in the called helper function. + auto cuModule = allocatePointer(builder, loc); + auto cuModuleLoad = + getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName); + builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, + builder.getSymbolRefAttr(cuModuleLoad), + ArrayRef<Value>{cuModule, data}); + // Get the function from the module. The name corresponds to the name of + // the kernel function. + auto cuOwningModuleRef = + builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule); + auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder); + auto cuFunction = allocatePointer(builder, loc); + auto cuModuleGetFunction = + getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName); + builder.create<LLVM::CallOp>( + loc, ArrayRef<Type>{getCUResultType()}, + builder.getSymbolRefAttr(cuModuleGetFunction), + ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName}); + // Grab the global stream needed for execution. + auto cuGetStreamHelper = + getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName); + auto cuStream = builder.create<LLVM::CallOp>( + loc, ArrayRef<Type>{getPointerType()}, + builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{}); + // Invoke the function with required arguments. 
+ auto cuLaunchKernel = + getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName); + auto cuFunctionRef = + builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction); + auto paramsArray = setupParamsArray(launchOp, builder); + auto nullpointer = + builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero); + builder.create<LLVM::CallOp>( + loc, ArrayRef<Type>{getCUResultType()}, + builder.getSymbolRefAttr(cuLaunchKernel), + ArrayRef<Value>{cuFunctionRef, launchOp.getOperand(0), + launchOp.getOperand(1), launchOp.getOperand(2), + launchOp.getOperand(3), launchOp.getOperand(4), + launchOp.getOperand(5), zero, /* sharedMemBytes */ + cuStream.getResult(0), /* stream */ + paramsArray, /* kernel params */ + nullpointer /* extra */}); + // Sync on the stream to make it synchronous. + auto cuStreamSync = + getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName); + builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, + builder.getSymbolRefAttr(cuStreamSync), + ArrayRef<Value>(cuStream.getResult(0))); + launchOp.erase(); +} + +std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> +mlir::createConvertGpuLaunchFuncToCudaCallsPass() { + return std::make_unique<GpuLaunchFuncToCudaCallsPass>(); +} + +static PassRegistration<GpuLaunchFuncToCudaCallsPass> + pass("launch-func-to-cuda", + "Convert all launch_func ops to CUDA runtime calls"); diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt new file mode 100644 index 00000000000..b5df446abe1 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -0,0 +1,18 @@ +set(LLVM_TARGET_DEFINITIONS GPUToNVVM.td) +mlir_tablegen(GPUToNVVM.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRGPUToNVVMIncGen) + +add_llvm_library(MLIRGPUtoNVVMTransforms + LowerGpuOpsToNVVMOps.cpp + ) + +add_dependencies(MLIRGPUtoNVVMTransforms + MLIRGPUToNVVMIncGen) + +target_link_libraries(MLIRGPUtoNVVMTransforms + LLVMSupport + MLIRGPU + MLIRLLVMIR + MLIRNVVMIR + MLIRPass + ) diff --git a/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td b/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td new file mode 100644 index 00000000000..0a6aec07041 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td @@ -0,0 +1,21 @@ +//==-- GPUToNVVM.td - GPU Ops to NVVM Patterns ---------------*- tablegen -*==// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines Patterns to lower GPU ops to NVVM. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_GPUTONVVM_TD +#define MLIR_CONVERSION_GPUTONVVM_TD + +include "mlir/Dialect/GPU/GPUOps.td" +include "mlir/Dialect/LLVMIR/NVVMOps.td" + +def : Pat<(GPU_BarrierOp), (NVVM_Barrier0Op)>; + +#endif // MLIR_CONVERSION_GPUTONVVM_TD diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp new file mode 100644 index 00000000000..08c18c1ec83 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -0,0 +1,751 @@ +//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to generate NVVMIR operations for higher-level +// GPU operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/Support/FormatVariadic.h" + +#include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" + +using namespace mlir; + +namespace { + +/// Derived type converter for GPU to NVVM lowering. The GPU dialect uses memory +/// space 5 for private memory attributions, but NVVM represents private +/// memory allocations as local `alloca`s in the default address space. This +/// converter drops the private memory space to support the use case above. +class NVVMTypeConverter : public LLVMTypeConverter { +public: + using LLVMTypeConverter::LLVMTypeConverter; + + Type convertType(Type type) override { + auto memref = type.dyn_cast<MemRefType>(); + if (memref && + memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) { + type = MemRefType::get(memref.getShape(), memref.getElementType(), + memref.getAffineMaps()); + } + + return LLVMTypeConverter::convertType(type); + } +}; + +/// Converts all_reduce op to LLVM/NVVM ops. +struct GPUAllReduceOpLowering : public LLVMOpLowering { + using AccumulatorFactory = + std::function<Value(Location, Value, Value, ConversionPatternRewriter &)>; + + explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_) + : LLVMOpLowering(gpu::AllReduceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), + int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value operand = operands.front(); + + // TODO(csigg): Generalize to other types of accumulation. + assert(op->getOperand(0)->getType().isIntOrFloat()); + + // Create the reduction using an accumulator factory. + AccumulatorFactory factory = + getFactory(cast<gpu::AllReduceOp>(op), operand); + assert(factory && "failed to create accumulator factory"); + Value result = createBlockReduce(loc, operand, factory, rewriter); + + rewriter.replaceOp(op, {result}); + return matchSuccess(); + } + +private: + /// Returns an accumulator factory using either the op attribute or the body + /// region. + AccumulatorFactory getFactory(gpu::AllReduceOp allReduce, + Value operand) const { + if (!allReduce.body().empty()) { + return getFactory(allReduce.body()); + } + if (allReduce.op()) { + auto type = operand->getType().cast<LLVM::LLVMType>(); + return getFactory(*allReduce.op(), type.getUnderlyingType()); + } + return AccumulatorFactory(); + } + + /// Returns an accumulator factory that clones the body. The body's entry + /// block is expected to have 2 arguments. The gpu.yield return the + /// accumulated value of the same type. 
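For a plain C++ picture of the accumulator-factory pattern used here, the sketch below swaps MLIR Values and the rewriter for ordinary ints. The names IntAccumulatorFactory and makeAccumulatorFactory are illustrative; they loosely parallel the AccumulatorFactory typedef above and the op-name dispatch in getFactory(StringRef, llvm::Type*) further down, which additionally picks float vs. integer variants.

#include <cassert>
#include <functional>
#include <string>

// Plain stand-in for the pass's AccumulatorFactory: given two partial
// results, produce the combined one.
using IntAccumulatorFactory = std::function<int(int, int)>;

// Dispatch on the op name, returning an empty factory for unknown ops.
IntAccumulatorFactory makeAccumulatorFactory(const std::string &opName) {
  if (opName == "add")
    return [](int lhs, int rhs) { return lhs + rhs; };
  if (opName == "mul")
    return [](int lhs, int rhs) { return lhs * rhs; };
  return nullptr;
}

int main() {
  auto add = makeAccumulatorFactory("add");
  auto mul = makeAccumulatorFactory("mul");
  assert(add && mul && !makeAccumulatorFactory("max"));
  // Fold a handful of partial results, loosely the way the warp/block
  // reduction chains the accumulator over intermediate values.
  int partials[4] = {1, 2, 3, 4};
  int acc = partials[0];
  for (int i = 1; i < 4; ++i)
    acc = add(acc, partials[i]);
  assert(acc == 10);
  return 0;
}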
+ AccumulatorFactory getFactory(Region &body) const { + return AccumulatorFactory([&](Location loc, Value lhs, Value rhs, + ConversionPatternRewriter &rewriter) { + Block *block = rewriter.getInsertionBlock(); + Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint()); + + // Insert accumulator body between split block. + BlockAndValueMapping mapping; + mapping.map(body.front().getArgument(0), lhs); + mapping.map(body.front().getArgument(1), rhs); + rewriter.cloneRegionBefore(body, *split->getParent(), + split->getIterator(), mapping); + + // Add branch before inserted body, into body. + block = block->getNextNode(); + rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{}, + llvm::makeArrayRef(block), ValueRange()); + + // Replace all gpu.yield ops with branch out of body. + for (; block != split; block = block->getNextNode()) { + Operation *terminator = block->getTerminator(); + if (!llvm::isa<gpu::YieldOp>(terminator)) + continue; + rewriter.setInsertionPointToEnd(block); + rewriter.replaceOpWithNewOp<LLVM::BrOp>( + terminator, ArrayRef<Value>{}, llvm::makeArrayRef(split), + ValueRange(terminator->getOperand(0))); + } + + // Return accumulator result. + rewriter.setInsertionPointToStart(split); + return split->addArgument(lhs->getType()); + }); + } + + /// Returns an accumulator factory that creates an op specified by opName. + AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const { + if (type->isVectorTy() || type->isArrayTy()) + return getFactory(opName, type->getSequentialElementType()); + + bool isFloatingPoint = type->isFloatingPointTy(); + + if (opName == "add") { + return isFloatingPoint ? getFactory<LLVM::FAddOp>() + : getFactory<LLVM::AddOp>(); + } + if (opName == "mul") { + return isFloatingPoint ? getFactory<LLVM::FMulOp>() + : getFactory<LLVM::MulOp>(); + } + + return AccumulatorFactory(); + } + + /// Returns an accumulator factory that creates an op of type T. + template <typename T> AccumulatorFactory getFactory() const { + return [](Location loc, Value lhs, Value rhs, + ConversionPatternRewriter &rewriter) { + return rewriter.create<T>(loc, lhs->getType(), lhs, rhs); + }; + } + + /// Creates an all_reduce across the block. + /// + /// First reduce the elements within a warp. The first thread of each warp + /// writes the intermediate result to shared memory. After synchronizing the + /// block, the first warp reduces the values from shared memory. The result + /// is broadcasted to all threads through shared memory. 
+ /// + /// %warp_reduce = `createWarpReduce(%operand)` + /// %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer + /// %zero = llvm.mlir.constant(0 : i32) : !llvm.i32 + /// %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32 + /// %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1 + /// %thread_idx = `getLinearThreadIndex()` : !llvm.i32 + /// llvm.cond_br %is_first_lane, ^then1, ^continue1 + /// ^then1: + /// %warp_id = `getWarpId()` + /// %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id] + /// llvm.store %store_dst, %warp_reduce + /// llvm.br ^continue1 + /// ^continue1: + /// nvvm.barrier0 + /// %num_warps = `getNumWarps()` : !llvm.i32 + /// %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps + /// %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero] + /// llvm.cond_br %is_first_lane, ^then2, ^continue2 + /// ^then2: + /// %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx] + /// %value = llvm.load %load_src + /// %result = `createWarpReduce(%value)` + /// llvm.store %result_ptr, %result + /// llvm.br ^continue2 + /// ^continue2: + /// nvvm.barrier0 + /// %result = llvm.load %result_ptr + /// return %result + /// + Value createBlockReduce(Location loc, Value operand, + AccumulatorFactory &accumFactory, + ConversionPatternRewriter &rewriter) const { + auto type = operand->getType().cast<LLVM::LLVMType>(); + + // Create shared memory array to store the warp reduction. + auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>(); + assert(module && "op must belong to a module"); + Value sharedMemPtr = + createSharedMemoryArray(loc, module, type, kWarpSize, rewriter); + + Value zero = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(0u)); + Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type); + Value isFirstLane = rewriter.create<LLVM::ICmpOp>( + loc, LLVM::ICmpPredicate::eq, laneId, zero); + Value threadIdx = getLinearThreadIndex(loc, rewriter); + Value blockSize = getBlockSize(loc, rewriter); + Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter); + + // Reduce elements within each warp to produce the intermediate results. + Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand, + accumFactory, rewriter); + + // Write the intermediate results to shared memory, using the first lane of + // each warp. + createPredicatedBlock(loc, rewriter, isFirstLane, [&] { + Value warpId = getDivideByWarpSize(threadIdx, rewriter); + Value storeDst = rewriter.create<LLVM::GEPOp>( + loc, type, sharedMemPtr, ArrayRef<Value>({zero, warpId})); + rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst); + }); + rewriter.create<NVVM::Barrier0Op>(loc); + + Value numWarps = getNumWarps(loc, blockSize, rewriter); + Value isValidWarp = rewriter.create<LLVM::ICmpOp>( + loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps); + Value resultPtr = rewriter.create<LLVM::GEPOp>( + loc, type, sharedMemPtr, ArrayRef<Value>({zero, zero})); + + // Use the first numWarps threads to reduce the intermediate results from + // shared memory. The final result is written to shared memory again. 
+ createPredicatedBlock(loc, rewriter, isValidWarp, [&] { + Value loadSrc = rewriter.create<LLVM::GEPOp>( + loc, type, sharedMemPtr, ArrayRef<Value>({zero, threadIdx})); + Value value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc); + Value result = createWarpReduce(loc, numWarps, laneId, value, + accumFactory, rewriter); + rewriter.create<LLVM::StoreOp>(loc, result, resultPtr); + }); + rewriter.create<NVVM::Barrier0Op>(loc); + + // Load and return result from shared memory. + Value result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr); + return result; + } + + /// Creates an if-block skeleton and calls the two factories to generate the + /// ops in the `then` and `else` block.. + /// + /// llvm.cond_br %condition, ^then, ^continue + /// ^then: + /// %then_operands = `thenOpsFactory()` + /// llvm.br ^continue(%then_operands) + /// ^else: + /// %else_operands = `elseOpsFactory()` + /// llvm.br ^continue(%else_operands) + /// ^continue(%block_operands): + /// + template <typename ThenOpsFactory, typename ElseOpsFactory> + void createIf(Location loc, ConversionPatternRewriter &rewriter, + Value condition, ThenOpsFactory &&thenOpsFactory, + ElseOpsFactory &&elseOpsFactory) const { + Block *currentBlock = rewriter.getInsertionBlock(); + auto currentPoint = rewriter.getInsertionPoint(); + + Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint); + Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin()); + Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin()); + + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition), + ArrayRef<Block *>{thenBlock, elseBlock}); + + auto addBranch = [&](ValueRange operands) { + rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{}, + llvm::makeArrayRef(continueBlock), + llvm::makeArrayRef(operands)); + }; + + rewriter.setInsertionPointToStart(thenBlock); + auto thenOperands = thenOpsFactory(); + addBranch(thenOperands); + + rewriter.setInsertionPointToStart(elseBlock); + auto elseOperands = elseOpsFactory(); + addBranch(elseOperands); + + assert(thenOperands.size() == elseOperands.size()); + rewriter.setInsertionPointToStart(continueBlock); + for (auto operand : thenOperands) + continueBlock->addArgument(operand->getType()); + } + + /// Shortcut for createIf with empty else block and no block operands. + template <typename Factory> + void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter, + Value condition, + Factory &&predicatedOpsFactory) const { + createIf( + loc, rewriter, condition, + [&] { + predicatedOpsFactory(); + return ArrayRef<Value>(); + }, + [&] { return ArrayRef<Value>(); }); + } + + /// Creates a reduction across the first activeWidth lanes of a warp. + /// The first lane returns the result, all others return values are undefined. + Value createWarpReduce(Location loc, Value activeWidth, Value laneId, + Value operand, AccumulatorFactory accumFactory, + ConversionPatternRewriter &rewriter) const { + Value warpSize = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); + Value isPartialWarp = rewriter.create<LLVM::ICmpOp>( + loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize); + auto type = operand->getType().cast<LLVM::LLVMType>(); + + createIf( + loc, rewriter, isPartialWarp, + // Generate reduction over a (potentially) partial warp. 
+ [&] { + Value value = operand; + Value one = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(1)); + // Bit mask of active lanes: `(1 << activeWidth) - 1`. + Value activeMask = rewriter.create<LLVM::SubOp>( + loc, int32Type, + rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth), + one); + // Clamp lane: `activeWidth - 1` + Value maskAndClamp = + rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one); + auto dialect = lowering.getDialect(); + auto predTy = LLVM::LLVMType::getInt1Ty(dialect); + auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy}); + auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); + + // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source + // lane is within the active range. All lanes contain the final + // result, but only the first lane's result is used. + for (int i = 1; i < kWarpSize; i <<= 1) { + Value offset = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(i)); + Value shfl = rewriter.create<NVVM::ShflBflyOp>( + loc, shflTy, activeMask, value, offset, maskAndClamp, + returnValueAndIsValidAttr); + Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>( + loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); + // Skip the accumulation if the shuffle op read from a lane outside + // of the active range. + createIf( + loc, rewriter, isActiveSrcLane, + [&] { + Value shflValue = rewriter.create<LLVM::ExtractValueOp>( + loc, type, shfl, rewriter.getIndexArrayAttr(0)); + return SmallVector<Value, 1>{ + accumFactory(loc, value, shflValue, rewriter)}; + }, + [&] { return llvm::makeArrayRef(value); }); + value = rewriter.getInsertionBlock()->getArgument(0); + } + return SmallVector<Value, 1>{value}; + }, + // Generate a reduction over the entire warp. This is a specialization + // of the above reduction with unconditional accumulation. + [&] { + Value value = operand; + Value activeMask = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(~0u)); + Value maskAndClamp = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); + for (int i = 1; i < kWarpSize; i <<= 1) { + Value offset = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(i)); + Value shflValue = rewriter.create<NVVM::ShflBflyOp>( + loc, type, activeMask, value, offset, maskAndClamp, + /*return_value_and_is_valid=*/UnitAttr()); + value = accumFactory(loc, value, shflValue, rewriter); + } + return SmallVector<Value, 1>{value}; + }); + return rewriter.getInsertionBlock()->getArgument(0); + } + + /// Creates a global array stored in shared memory. + Value createSharedMemoryArray(Location loc, ModuleOp module, + LLVM::LLVMType elementType, int numElements, + ConversionPatternRewriter &rewriter) const { + OpBuilder builder(module.getBodyRegion()); + + auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); + StringRef name = "reduce_buffer"; + auto globalOp = builder.create<LLVM::GlobalOp>( + loc, arrayType.cast<LLVM::LLVMType>(), + /*isConstant=*/false, LLVM::Linkage::Internal, name, + /*value=*/Attribute(), gpu::GPUDialect::getWorkgroupAddressSpace()); + + return rewriter.create<LLVM::AddressOfOp>(loc, globalOp); + } + + /// Returns the index of the thread within the block. 
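  // Host-side model of the index helpers defined below, kept as a comment
  // because it is only an illustrative sketch (plain `int`s stand in for LLVM
  // dialect values and the warp size is fixed at 32):
  //
  //   int linearThreadIndex(int tidX, int tidY, int tidZ, int dimX, int dimY) {
  //     return (tidZ * dimY + tidY) * dimX + tidX;
  //   }
  //   int blockSize(int dimX, int dimY, int dimZ) { return dimX * dimY * dimZ; }
  //   int numWarps(int blockSize) { return (blockSize + 31) / 32; }
  //   int activeWidth(int threadIdx, int blockSize) {
  //     return blockSize - (threadIdx & ~31); // not clamped to the warp size
  //   }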
+ Value getLinearThreadIndex(Location loc, + ConversionPatternRewriter &rewriter) const { + Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type); + Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type); + Value idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type); + Value idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type); + Value idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type); + Value tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY); + Value tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY); + Value tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX); + return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX); + } + + /// Returns the number of threads in the block. + Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const { + Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type); + Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type); + Value dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type); + Value dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY); + return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ); + } + + /// Returns the number of warps in the block. + Value getNumWarps(Location loc, Value blockSize, + ConversionPatternRewriter &rewriter) const { + auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1)); + auto biasedBlockSize = rewriter.create<LLVM::AddOp>( + loc, int32Type, blockSize, warpSizeMinusOne); + return getDivideByWarpSize(biasedBlockSize, rewriter); + } + + /// Returns the number of active threads in the warp, not clamped to 32. + Value getActiveWidth(Location loc, Value threadIdx, Value blockSize, + ConversionPatternRewriter &rewriter) const { + Value threadIdxMask = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1))); + Value numThreadsWithSmallerWarpId = + rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask); + return rewriter.create<LLVM::SubOp>(loc, blockSize, + numThreadsWithSmallerWarpId); + } + + /// Returns value divided by the warp size (i.e. 32). + Value getDivideByWarpSize(Value value, + ConversionPatternRewriter &rewriter) const { + auto loc = value->getLoc(); + auto warpSize = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize)); + return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize); + } + + LLVM::LLVMType int32Type; + + static constexpr int kWarpSize = 32; +}; + +struct GPUShuffleOpLowering : public LLVMOpLowering { + explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_) + : LLVMOpLowering(gpu::ShuffleOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_) {} + + /// Lowers a shuffle to the corresponding NVVM op. + /// + /// Convert the `width` argument into an activeMask (a bitmask which specifies + /// which threads participate in the shuffle) and a maskAndClamp (specifying + /// the highest lane which participates in the shuffle). 
+ /// + /// %one = llvm.constant(1 : i32) : !llvm.i32 + /// %shl = llvm.shl %one, %width : !llvm.i32 + /// %active_mask = llvm.sub %shl, %one : !llvm.i32 + /// %mask_and_clamp = llvm.sub %width, %one : !llvm.i32 + /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset, + /// %mask_and_clamp : !llvm<"{ float, i1 }"> + /// %shfl_value = llvm.extractvalue %shfl[0 : index] : + /// !llvm<"{ float, i1 }"> + /// %shfl_pred = llvm.extractvalue %shfl[1 : index] : + /// !llvm<"{ float, i1 }"> + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + gpu::ShuffleOpOperandAdaptor adaptor(operands); + + auto dialect = lowering.getDialect(); + auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>(); + auto int32Type = LLVM::LLVMType::getInt32Ty(dialect); + auto predTy = LLVM::LLVMType::getInt1Ty(dialect); + auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy}); + + Value one = rewriter.create<LLVM::ConstantOp>( + loc, int32Type, rewriter.getI32IntegerAttr(1)); + // Bit mask of active lanes: `(1 << activeWidth) - 1`. + Value activeMask = rewriter.create<LLVM::SubOp>( + loc, int32Type, + rewriter.create<LLVM::ShlOp>(loc, int32Type, one, adaptor.width()), + one); + // Clamp lane: `activeWidth - 1` + Value maskAndClamp = + rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one); + + auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); + Value shfl = rewriter.create<NVVM::ShflBflyOp>( + loc, resultTy, activeMask, adaptor.value(), adaptor.offset(), + maskAndClamp, returnValueAndIsValidAttr); + Value shflValue = rewriter.create<LLVM::ExtractValueOp>( + loc, valueTy, shfl, rewriter.getIndexArrayAttr(0)); + Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>( + loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); + + rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); + return matchSuccess(); + } +}; + +struct GPUFuncOpLowering : LLVMOpLowering { + explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter) + : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(), + typeConverter.getDialect()->getContext(), + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.empty() && "func op is not expected to have operands"); + auto gpuFuncOp = cast<gpu::GPUFuncOp>(op); + Location loc = gpuFuncOp.getLoc(); + + SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; + workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); + for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { + Value attribution = en.value(); + + auto type = attribution->getType().dyn_cast<MemRefType>(); + assert(type && type.hasStaticShape() && "unexpected type in attribution"); + + uint64_t numElements = type.getNumElements(); + + auto elementType = + lowering.convertType(type.getElementType()).cast<LLVM::LLVMType>(); + auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); + std::string name = + llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()); + auto globalOp = rewriter.create<LLVM::GlobalOp>( + gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + LLVM::Linkage::Internal, name, /*value=*/Attribute(), + gpu::GPUDialect::getWorkgroupAddressSpace()); + workgroupBuffers.push_back(globalOp); + } + + // Rewrite the original GPU function to an LLVM function. 
+ auto funcType = lowering.convertType(gpuFuncOp.getType()) + .cast<LLVM::LLVMType>() + .getPointerElementTy(); + + // Remap proper input types. + TypeConverter::SignatureConversion signatureConversion( + gpuFuncOp.front().getNumArguments()); + for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i) + signatureConversion.addInputs(i, funcType.getFunctionParamType(i)); + + // Create the new function operation. Only copy those attributes that are + // not specific to function modeling. + SmallVector<NamedAttribute, 4> attributes; + for (const auto &attr : gpuFuncOp.getAttrs()) { + if (attr.first.is(SymbolTable::getSymbolAttrName()) || + attr.first.is(impl::getTypeAttrName()) || + attr.first.is(gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())) + continue; + attributes.push_back(attr); + } + auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( + gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + LLVM::Linkage::External, attributes); + + { + // Insert operations that correspond to converted workgroup and private + // memory attributions to the body of the function. This must operate on + // the original function, before the body region is inlined in the new + // function to maintain the relation between block arguments and the + // parent operation that assigns their semantics. + OpBuilder::InsertionGuard guard(rewriter); + + // Rewrite workgroup memory attributions to addresses of global buffers. + rewriter.setInsertionPointToStart(&gpuFuncOp.front()); + unsigned numProperArguments = gpuFuncOp.getNumArguments(); + auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); + + Value zero = nullptr; + if (!workgroupBuffers.empty()) + zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, + rewriter.getI32IntegerAttr(0)); + for (auto en : llvm::enumerate(workgroupBuffers)) { + LLVM::GlobalOp global = en.value(); + Value address = rewriter.create<LLVM::AddressOfOp>(loc, global); + auto elementType = global.getType().getArrayElementType(); + Value memory = rewriter.create<LLVM::GEPOp>( + loc, elementType.getPointerTo(global.addr_space().getZExtValue()), + address, ArrayRef<Value>{zero, zero}); + + // Build a memref descriptor pointing to the buffer to plug with the + // existing memref infrastructure. This may use more registers than + // otherwise necessary given that memref sizes are fixed, but we can try + // and canonicalize that away later. + Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; + auto type = attribution->getType().cast<MemRefType>(); + auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, + type, memory); + signatureConversion.remapInput(numProperArguments + en.index(), descr); + } + + // Rewrite private memory attributions to alloca'ed buffers. + unsigned numWorkgroupAttributions = + gpuFuncOp.getNumWorkgroupAttributions(); + auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { + Value attribution = en.value(); + auto type = attribution->getType().cast<MemRefType>(); + assert(type && type.hasStaticShape() && + "unexpected type in attribution"); + + // Explicitly drop memory space when lowering private memory + // attributions since NVVM models it as `alloca`s in the default + // memory space and does not support `alloca`s with addrspace(5). 
+ auto ptrType = lowering.convertType(type.getElementType()) + .cast<LLVM::LLVMType>() + .getPointerTo(); + Value numElements = rewriter.create<LLVM::ConstantOp>( + gpuFuncOp.getLoc(), int64Ty, + rewriter.getI64IntegerAttr(type.getNumElements())); + Value allocated = rewriter.create<LLVM::AllocaOp>( + gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); + auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, + type, allocated); + signatureConversion.remapInput( + numProperArguments + numWorkgroupAttributions + en.index(), descr); + } + } + + // Move the region to the new function, update the entry block signature. + rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), + llvmFuncOp.end()); + rewriter.applySignatureConversion(&llvmFuncOp.getBody(), + signatureConversion); + + { + // For memref-typed arguments, insert the relevant loads in the beginning + // of the block to comply with the LLVM dialect calling convention. This + // needs to be done after signature conversion to get the right types. + OpBuilder::InsertionGuard guard(rewriter); + Block &block = llvmFuncOp.front(); + rewriter.setInsertionPointToStart(&block); + + for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) { + if (!en.value().isa<MemRefType>() && + !en.value().isa<UnrankedMemRefType>()) + continue; + + BlockArgument arg = block.getArgument(en.index()); + Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); + rewriter.replaceUsesOfBlockArgument(arg, loaded); + } + } + + rewriter.eraseOp(gpuFuncOp); + return matchSuccess(); + } +}; + +struct GPUReturnOpLowering : public LLVMOpLowering { + GPUReturnOpLowering(LLVMTypeConverter &typeConverter) + : LLVMOpLowering(gpu::ReturnOp::getOperationName(), + typeConverter.getDialect()->getContext(), + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands, + ArrayRef<Block *>()); + return matchSuccess(); + } +}; + +/// Import the GPU Ops to NVVM Patterns. +#include "GPUToNVVM.cpp.inc" + +/// A pass that replaces all occurrences of GPU device operations with their +/// corresponding NVVM equivalent. +/// +/// This pass only handles device code and is not meant to be run on GPU host +/// code. +class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> { +public: + void runOnModule() override { + ModuleOp m = getModule(); + if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName())) + return; + + OwningRewritePatternList patterns; + NVVMTypeConverter converter(m.getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + populateGpuToNVVMConversionPatterns(converter, patterns); + ConversionTarget target(getContext()); + target.addIllegalDialect<gpu::GPUDialect>(); + target.addIllegalOp<LLVM::ExpOp>(); + target.addIllegalOp<FuncOp>(); + target.addLegalDialect<LLVM::LLVMDialect>(); + target.addLegalDialect<NVVM::NVVMDialect>(); + // TODO(csigg): Remove once we support replacing non-root ops. 
+ target.addLegalOp<gpu::YieldOp>(); + if (failed(applyPartialConversion(m, target, patterns, &converter))) + signalPassFailure(); + } +}; + +} // anonymous namespace + +void mlir::populateGpuToNVVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateWithGenerated(converter.getDialect()->getContext(), &patterns); + patterns + .insert<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp, + NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>, + GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp, + NVVM::BlockDimYOp, NVVM::BlockDimZOp>, + GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp, + NVVM::BlockIdYOp, NVVM::BlockIdZOp>, + GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp, + NVVM::GridDimYOp, NVVM::GridDimZOp>, + GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering, + GPUReturnOpLowering>(converter); + patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf", + "__nv_exp"); +} + +std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() { + return std::make_unique<LowerGpuOpsToNVVMOpsPass>(); +} + +static PassRegistration<LowerGpuOpsToNVVMOpsPass> + pass("convert-gpu-to-nvvm", "Generate NVVM operations for gpu operations"); diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt new file mode 100644 index 00000000000..3c97e5ca86b --- /dev/null +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -0,0 +1,10 @@ +add_llvm_library(MLIRGPUtoROCDLTransforms + LowerGpuOpsToROCDLOps.cpp + ) +target_link_libraries(MLIRGPUtoROCDLTransforms + LLVMSupport + MLIRGPU + MLIRLLVMIR + MLIRROCDLIR + MLIRPass + ) diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp new file mode 100644 index 00000000000..83770641bd4 --- /dev/null +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -0,0 +1,75 @@ +//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to generate ROCDLIR operations for higher-level +// GPU operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" + +using namespace mlir; + +namespace { + +// A pass that replaces all occurrences of GPU device operations with their +// corresponding ROCDL equivalent. +// +// This pass only handles device code and is not meant to be run on GPU host +// code. 
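// A minimal sketch of scheduling the pass defined directly below from C++.
// This assumes the mlir::PassManager API of this revision (constructor taking
// an MLIRContext*, run() taking a ModuleOp, and mlir/Pass/PassManager.h being
// available); in practice the pass is normally driven through mlir-opt via the
// "convert-gpu-to-rocdl" registration at the end of this file.
static LogicalResult runGpuToROCDLSketch(ModuleOp kernelModule) {
  PassManager pm(kernelModule.getContext());
  pm.addPass(createLowerGpuOpsToROCDLOpsPass());
  return pm.run(kernelModule);
}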
+class LowerGpuOpsToROCDLOpsPass : public ModulePass<LowerGpuOpsToROCDLOpsPass> { +public: + void runOnModule() override { + ModuleOp m = getModule(); + if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName())) + return; + + OwningRewritePatternList patterns; + LLVMTypeConverter converter(m.getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + patterns.insert< + GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp, + ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>, + GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp, + ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>, + GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, ROCDL::BlockIdXOp, + ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>, + GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp, + ROCDL::GridDimYOp, ROCDL::GridDimZOp>>( + converter); + patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "_ocml_exp_f32", + "_ocml_exp_f64"); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>(); + target.addIllegalOp<LLVM::ExpOp>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + if (failed(applyPartialConversion(m, target, patterns, &converter))) + signalPassFailure(); + } +}; + +} // anonymous namespace + +std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToROCDLOpsPass() { + return std::make_unique<LowerGpuOpsToROCDLOpsPass>(); +} + +static PassRegistration<LowerGpuOpsToROCDLOpsPass> + pass("convert-gpu-to-rocdl", + "Generate ROCDL operations for gpu operations"); diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt new file mode 100644 index 00000000000..be82894461d --- /dev/null +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRGPUtoSPIRVTransforms + ConvertGPUToSPIRV.cpp + ConvertGPUToSPIRVPass.cpp + ) + +target_link_libraries(MLIRGPUtoSPIRVTransforms + MLIRGPU + MLIRIR + MLIRPass + MLIRSPIRV + MLIRStandardOps + MLIRStandardToSPIRVTransforms + MLIRSupport + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp new file mode 100644 index 00000000000..509457d076a --- /dev/null +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -0,0 +1,359 @@ +//===- ConvertGPUToSPIRV.cpp - Convert GPU ops to SPIR-V dialect ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the conversion patterns from GPU ops to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/Module.h" + +using namespace mlir; + +namespace { + +/// Pattern to convert a loop::ForOp within kernel functions into spirv::LoopOp. 
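// The spv.loop emitted by the pattern below has the shape sketched here as
// plain C++ control flow (illustrative only; `lb`, `ub` and `step` stand for
// the converted loop operands, and each label marks one block of the loop):
static void forOpLoweringSketch(int64_t lb, int64_t ub, int64_t step) {
  int64_t iv = lb;  // the entry block branches to the header, passing %lb
header:
  if (!(iv < ub))   // header block: spv.SLessThan feeding a conditional
    goto merge;     // branch to either the body or the merge block
  /* cloned loop body */
  iv += step;       // continue block: spv.IAdd on the induction variable
  goto header;      // the single back edge into the header
merge:
  return;           // merge block: the structured loop exit
}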
+class ForOpConversion final : public SPIRVOpLowering<loop::ForOp> { +public: + using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation +/// builin variables. +template <typename SourceOp, spirv::BuiltIn builtin> +class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> { +public: + using SPIRVOpLowering<SourceOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(SourceOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Pattern to convert a kernel function in GPU dialect within a spv.module. +class KernelFnConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> { +public: + KernelFnConversion(MLIRContext *context, SPIRVTypeConverter &converter, + ArrayRef<int64_t> workGroupSize, + PatternBenefit benefit = 1) + : SPIRVOpLowering<gpu::GPUFuncOp>(context, converter, benefit) { + auto config = workGroupSize.take_front(3); + workGroupSizeAsInt32.assign(config.begin(), config.end()); + workGroupSizeAsInt32.resize(3, 1); + } + + PatternMatchResult + matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; + +private: + SmallVector<int32_t, 3> workGroupSizeAsInt32; +}; + +/// Pattern to convert a module with gpu.kernel_module attribute to a +/// spv.module. +class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> { +public: + using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Pattern to convert a module terminator op to a terminator of spv.module op. +// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined +// in ODS. +class KernelModuleTerminatorConversion final + : public SPIRVOpLowering<ModuleTerminatorOp> { +public: + using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Pattern to convert a gpu.return into a SPIR-V return. +// TODO: This can go to DRR when GPU return has operands. +class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> { +public: + using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// loop::ForOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + // loop::ForOp can be lowered to the structured control flow represented by + // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop + // latch and the merge block the exit block. The resulting spirv::LoopOp has a + // single back edge from the continue to header block, and a single exit from + // header to merge. 
+ loop::ForOpOperandAdaptor forOperands(operands); + auto loc = forOp.getLoc(); + auto loopControl = rewriter.getI32IntegerAttr( + static_cast<uint32_t>(spirv::LoopControl::None)); + auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl); + loopOp.addEntryAndMergeBlock(); + + OpBuilder::InsertionGuard guard(rewriter); + // Create the block for the header. + auto header = new Block(); + // Insert the header. + loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); + + // Create the new induction variable to use. + BlockArgument newIndVar = + header->addArgument(forOperands.lowerBound()->getType()); + Block *body = forOp.getBody(); + + // Apply signature conversion to the body of the forOp. It has a single block, + // with argument which is the induction variable. That has to be replaced with + // the new induction variable. + TypeConverter::SignatureConversion signatureConverter( + body->getNumArguments()); + signatureConverter.remapInput(0, newIndVar); + body = rewriter.applySignatureConversion(&forOp.getLoopBody(), + signatureConverter); + + // Delete the loop terminator. + rewriter.eraseOp(body->getTerminator()); + + // Move the blocks from the forOp into the loopOp. This is the body of the + // loopOp. + rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(), + std::next(loopOp.body().begin(), 2)); + + // Branch into it from the entry. + rewriter.setInsertionPointToEnd(&(loopOp.body().front())); + rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound()); + + // Generate the rest of the loop header. + rewriter.setInsertionPointToEnd(header); + auto mergeBlock = loopOp.getMergeBlock(); + auto cmpOp = rewriter.create<spirv::SLessThanOp>( + loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); + rewriter.create<spirv::BranchConditionalOp>( + loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); + + // Generate instructions to increment the step of the induction variable and + // branch to the header. + Block *continueBlock = loopOp.getContinueBlock(); + rewriter.setInsertionPointToEnd(continueBlock); + + // Add the step to the induction variable and branch to the header. + Value updatedIndVar = rewriter.create<spirv::IAddOp>( + loc, newIndVar->getType(), newIndVar, forOperands.step()); + rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); + + rewriter.eraseOp(forOp); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Builtins. 
+//===----------------------------------------------------------------------===// + +template <typename SourceOp, spirv::BuiltIn builtin> +PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( + SourceOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + auto dimAttr = + op.getOperation()->template getAttrOfType<StringAttr>("dimension"); + if (!dimAttr) { + return this->matchFailure(); + } + int32_t index = 0; + if (dimAttr.getValue() == "x") { + index = 0; + } else if (dimAttr.getValue() == "y") { + index = 1; + } else if (dimAttr.getValue() == "z") { + index = 2; + } else { + return this->matchFailure(); + } + + // SPIR-V invocation builtin variables are a vector of type <3xi32> + auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); + rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( + op, rewriter.getIntegerType(32), spirvBuiltin, + rewriter.getI32ArrayAttr({index})); + return this->matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// GPUFuncOp +//===----------------------------------------------------------------------===// + +// Legalizes a GPU function as an entry SPIR-V function. +static FuncOp +lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter, + spirv::EntryPointABIAttr entryPointInfo, + ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) { + auto fnType = funcOp.getType(); + if (fnType.getNumResults()) { + funcOp.emitError("SPIR-V lowering only supports entry functions " + "with no return values right now"); + return nullptr; + } + if (fnType.getNumInputs() != argABIInfo.size()) { + funcOp.emitError( + "lowering as entry functions requires ABI info for all arguments"); + return nullptr; + } + // Update the signature to valid SPIR-V types and add the ABI + // attributes. These will be "materialized" by using the + // LowerABIAttributesPass.
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); + { + for (auto argType : enumerate(funcOp.getType().getInputs())) { + auto convertedType = typeConverter.convertType(argType.value()); + signatureConverter.addInputs(argType.index(), convertedType); + } + } + auto newFuncOp = rewriter.create<FuncOp>( + funcOp.getLoc(), funcOp.getName(), + rewriter.getFunctionType(signatureConverter.getConvertedTypes(), + llvm::None), + ArrayRef<NamedAttribute>()); + for (const auto &namedAttr : funcOp.getAttrs()) { + if (namedAttr.first.is(impl::getTypeAttrName()) || + namedAttr.first.is(SymbolTable::getSymbolAttrName())) + continue; + newFuncOp.setAttr(namedAttr.first, namedAttr.second); + } + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + rewriter.eraseOp(funcOp); + + spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo); + return newFuncOp; +} + +PatternMatchResult +KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp, + ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (!gpu::GPUDialect::isKernel(funcOp)) { + return matchFailure(); + } + + SmallVector<spirv::InterfaceVarABIAttr, 4> argABI; + for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { + argABI.push_back(spirv::getInterfaceVarABIAttr( + 0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext())); + } + + auto context = rewriter.getContext(); + auto entryPointAttr = + spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context); + FuncOp newFuncOp = lowerAsEntryFunction(funcOp, typeConverter, rewriter, + entryPointAttr, argABI); + if (!newFuncOp) { + return matchFailure(); + } + newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(), + rewriter.getContext())); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// ModuleOp with gpu.kernel_module. +//===----------------------------------------------------------------------===// + +PatternMatchResult KernelModuleConversion::matchAndRewrite( + ModuleOp moduleOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (!moduleOp.getAttrOfType<UnitAttr>( + gpu::GPUDialect::getKernelModuleAttrName())) { + return matchFailure(); + } + // TODO : Generalize this to account for different extensions, + // capabilities, extended_instruction_sets, other addressing models + // and memory models. + auto spvModule = rewriter.create<spirv::ModuleOp>( + moduleOp.getLoc(), spirv::AddressingModel::Logical, + spirv::MemoryModel::GLSL450, spirv::Capability::Shader, + spirv::Extension::SPV_KHR_storage_buffer_storage_class); + // Move the region from the module op into the SPIR-V module. + Region &spvModuleRegion = spvModule.getOperation()->getRegion(0); + rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion, + spvModuleRegion.begin()); + // The spv.module build method adds a block with a terminator. Remove that + // block. The terminator of the module op in the remaining block will be + // legalized later. + spvModuleRegion.back().erase(); + rewriter.eraseOp(moduleOp); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// ModuleTerminatorOp for gpu.kernel_module. 
+//===----------------------------------------------------------------------===// + +PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite( + ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// GPU return inside kernel functions to SPIR-V return. +//===----------------------------------------------------------------------===// + +PatternMatchResult GPUReturnOpConversion::matchAndRewrite( + gpu::ReturnOp returnOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (!operands.empty()) + return matchFailure(); + + rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// GPU To SPIRV Patterns. +//===----------------------------------------------------------------------===// + +void mlir::populateGPUToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns, + ArrayRef<int64_t> workGroupSize) { + patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize); + patterns.insert< + GPUReturnOpConversion, ForOpConversion, KernelModuleConversion, + KernelModuleTerminatorConversion, + LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>, + LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>, + LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>, + LaunchConfigConversion<gpu::ThreadIdOp, + spirv::BuiltIn::LocalInvocationId>>(context, + typeConverter); +} diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp new file mode 100644 index 00000000000..68392c36765 --- /dev/null +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -0,0 +1,96 @@ +//===- ConvertGPUToSPIRVPass.cpp - GPU to SPIR-V dialect lowering passes --===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert a kernel function in the GPU Dialect +// into a spv.module operation +// +//===----------------------------------------------------------------------===// +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +using namespace mlir; + +namespace { +/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions +/// that have the "gpu.kernel" attribute, i.e. those functions that are +/// referenced in gpu::LaunchKernelOp operations. For each such function +/// +/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp +/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot +/// replace it). 
+/// +/// 2) Lower the body of the spirv::ModuleOp. +class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> { +public: + GPUToSPIRVPass() = default; + GPUToSPIRVPass(const GPUToSPIRVPass &) {} + GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) { + this->workGroupSize = workGroupSize; + } + + void runOnModule() override; + +private: + /// Command line option to specify the workgroup size. + ListOption<int64_t> workGroupSize{ + *this, "workgroup-size", + llvm::cl::desc( + "Workgroup Sizes in the SPIR-V module for x, followed by y, followed " + "by z dimension of the dispatch (others will be ignored)"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; +}; +} // namespace + +void GPUToSPIRVPass::runOnModule() { + auto context = &getContext(); + auto module = getModule(); + + SmallVector<Operation *, 1> kernelModules; + OpBuilder builder(context); + module.walk([&builder, &kernelModules](ModuleOp moduleOp) { + if (moduleOp.getAttrOfType<UnitAttr>( + gpu::GPUDialect::getKernelModuleAttrName())) { + // For each kernel module (should be only 1 for now, but that is not a + // requirement here), clone the module for conversion because the + // gpu.launch function still needs the kernel module. + builder.setInsertionPoint(moduleOp.getOperation()); + kernelModules.push_back(builder.clone(*moduleOp.getOperation())); + } + }); + + SPIRVTypeConverter typeConverter; + OwningRewritePatternList patterns; + populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize); + populateStandardToSPIRVPatterns(context, typeConverter, patterns); + + ConversionTarget target(*context); + target.addLegalDialect<spirv::SPIRVDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); + + if (failed(applyFullConversion(kernelModules, target, patterns, + &typeConverter))) { + return signalPassFailure(); + } +} + +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) { + return std::make_unique<GPUToSPIRVPass>(workGroupSize); +} + +static PassRegistration<GPUToSPIRVPass> + pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect"); diff --git a/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt b/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt new file mode 100644 index 00000000000..9d2b5dac202 --- /dev/null +++ b/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRLinalgToLLVM + LinalgToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToLLVM +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRLinalgToLLVM ${LIBS}) +target_link_libraries(MLIRLinalgToLLVM ${LIBS}) diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp new file mode 100644 index 00000000000..2a034fd15c5 --- /dev/null +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -0,0 +1,549 @@ +//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::LLVM; +using namespace mlir::linalg; +using namespace mlir::linalg::intrinsics; + +using add = ValueBuilder<mlir::LLVM::AddOp>; +using addi = ValueBuilder<mlir::AddIOp>; +using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>; +using cmpi = ValueBuilder<mlir::CmpIOp>; +using constant = ValueBuilder<mlir::LLVM::ConstantOp>; +using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>; +using gep = ValueBuilder<mlir::LLVM::GEPOp>; +using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>; +using llvm_call = OperationBuilder<mlir::LLVM::CallOp>; +using llvm_icmp = ValueBuilder<LLVM::ICmpOp>; +using llvm_load = ValueBuilder<LLVM::LoadOp>; +using llvm_store = OperationBuilder<LLVM::StoreOp>; +using llvm_select = ValueBuilder<LLVM::SelectOp>; +using mul = ValueBuilder<mlir::LLVM::MulOp>; +using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>; +using sub = ValueBuilder<mlir::LLVM::SubOp>; +using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>; +using urem = ValueBuilder<mlir::LLVM::URemOp>; +using llvm_alloca = ValueBuilder<LLVM::AllocaOp>; +using llvm_return = OperationBuilder<LLVM::ReturnOp>; + +template <typename T> +static LLVMType getPtrToElementType(T containerType, + LLVMTypeConverter &lowering) { + return lowering.convertType(containerType.getElementType()) + .template cast<LLVMType>() + .getPointerTo(); +} + +// Convert the given type to the LLVM IR Dialect type. The following +// conversions are supported: +// - an Index type is converted into an LLVM integer type with pointer +// bitwidth (analogous to intptr_t in C); +// - an Integer type is converted into an LLVM integer type of the same width; +// - an F32 type is converted into an LLVM float type +// - a Buffer, Range or View is converted into an LLVM structure type +// containing the respective dynamic values. 
+static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { + auto *context = t.getContext(); + auto int64Ty = lowering.convertType(IntegerType::get(64, context)) + .cast<LLVM::LLVMType>(); + + // Range descriptor contains the range bounds and the step as 64-bit integers. + // + // struct { + // int64_t min; + // int64_t max; + // int64_t step; + // }; + if (t.isa<RangeType>()) + return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); + + return Type(); +} + +namespace { +/// EDSC-compatible wrapper for MemRefDescriptor. +class BaseViewConversionHelper { +public: + BaseViewConversionHelper(Type type) + : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} + + BaseViewConversionHelper(Value v) : d(v) {} + + /// Wrappers around MemRefDescriptor that use EDSC builder and location. + Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } + void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } + Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } + void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } + Value offset() { return d.offset(rewriter(), loc()); } + void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } + Value size(unsigned i) { return d.size(rewriter(), loc(), i); } + void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } + Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } + void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } + + operator Value() { return d; } + +private: + OpBuilder &rewriter() { return ScopedContext::getBuilder(); } + Location loc() { return ScopedContext::getLocation(); } + + MemRefDescriptor d; +}; +} // namespace + +// RangeOp creates a new range descriptor. +class RangeOpConversion : public LLVMOpLowering { +public: + explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) + : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto rangeOp = cast<RangeOp>(op); + auto rangeDescriptorTy = + convertLinalgType(rangeOp.getResult()->getType(), lowering); + + edsc::ScopedContext context(rewriter, op->getLoc()); + + // Fill in an aggregate value of the descriptor. + RangeOpOperandAdaptor adaptor(operands); + Value desc = llvm_undef(rangeDescriptorTy); + desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); + desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); + desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +/// Conversion pattern that transforms a linalg.slice op into: +/// 1. A function entry `alloca` operation to allocate a ViewDescriptor. +/// 2. A load of the ViewDescriptor from the pointer allocated in 1. +/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size +/// and stride corresponding to the region of memory within the bounds of +/// the parent view. +/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. +/// The linalg.slice op is replaced by the alloca'ed pointer. 
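// Scalar model of the per-dimension arithmetic that the pattern below emits
// for a range indexing (illustrative only; the real pattern operates on LLVM
// dialect values, and scalar indexings only contribute to the offset before
// being projected away). The struct and helper names are hypothetical.
struct RangeSketch {
  int64_t min, max, step;
};

static void sliceDimSketch(RangeSketch r, int64_t baseSize, int64_t baseStride,
                           int64_t &offset, int64_t &size, int64_t &stride) {
  offset += r.min * baseStride;                      // base offset update
  int64_t max = r.max < baseSize ? r.max : baseSize; // clamp by parent size
  size = (max - r.min) > 0 ? (max - r.min) : 0;      // clamp below by zero
  stride = baseStride * r.step;                      // scale stride by step
}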
+class SliceOpConversion : public LLVMOpLowering { +public: + explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) + : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + edsc::ScopedContext context(rewriter, op->getLoc()); + SliceOpOperandAdaptor adaptor(operands); + BaseViewConversionHelper baseDesc(adaptor.view()); + + auto sliceOp = cast<SliceOp>(op); + auto memRefType = sliceOp.getBaseViewType(); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)) + .cast<LLVM::LLVMType>(); + + BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); + + // TODO(ntv): extract sizes and emit asserts. + SmallVector<Value, 4> strides(memRefType.getRank()); + for (int i = 0, e = memRefType.getRank(); i < e; ++i) + strides[i] = baseDesc.stride(i); + + auto pos = [&rewriter](ArrayRef<int64_t> values) { + return rewriter.getI64ArrayAttr(values); + }; + + // Compute base offset. + Value baseOffset = baseDesc.offset(); + for (int i = 0, e = memRefType.getRank(); i < e; ++i) { + Value indexing = adaptor.indexings()[i]; + Value min = indexing; + if (sliceOp.indexing(i)->getType().isa<RangeType>()) + min = extractvalue(int64Ty, indexing, pos(0)); + baseOffset = add(baseOffset, mul(min, strides[i])); + } + + // Insert the base and aligned pointers. + desc.setAllocatedPtr(baseDesc.allocatedPtr()); + desc.setAlignedPtr(baseDesc.alignedPtr()); + + // Insert base offset. + desc.setOffset(baseOffset); + + // Corner case, no sizes or strides: early return the descriptor. + if (sliceOp.getViewType().getRank() == 0) + return rewriter.replaceOp(op, {desc}), matchSuccess(); + + Value zero = + constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + // Compute and insert view sizes (max - min along the range) and strides. + // Skip the non-range operands as they will be projected away from the view. + int numNewDims = 0; + for (auto en : llvm::enumerate(sliceOp.indexings())) { + Value indexing = en.value(); + if (indexing->getType().isa<RangeType>()) { + int rank = en.index(); + Value rangeDescriptor = adaptor.indexings()[rank]; + Value min = extractvalue(int64Ty, rangeDescriptor, pos(0)); + Value max = extractvalue(int64Ty, rangeDescriptor, pos(1)); + Value step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + Value baseSize = baseDesc.size(rank); + + // Bound upper by base view upper bound. + max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, + baseSize); + Value size = sub(max, min); + // Bound lower by zero. + size = + llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); + Value stride = mul(strides[rank], step); + desc.setSize(numNewDims, size); + desc.setStride(numNewDims, stride); + ++numNewDims; + } + } + + rewriter.replaceOp(op, {desc}); + return matchSuccess(); + } +}; + +/// Conversion pattern that transforms a linalg.transpose op into: +/// 1. A function entry `alloca` operation to allocate a ViewDescriptor. +/// 2. A load of the ViewDescriptor from the pointer allocated in 1. +/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size +/// and stride. Size and stride are permutations of the original values. +/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. +/// The linalg.transpose op is replaced by the alloca'ed pointer. 
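// Scalar model of the size/stride permutation applied by the pattern below
// (illustrative only; names are hypothetical). Only the descriptor metadata is
// permuted: the data pointer, offset and underlying buffer are left untouched.
static void permuteDimsSketch(const int64_t *baseSizes,
                              const int64_t *baseStrides, const int64_t *perm,
                              int64_t rank, int64_t *sizes, int64_t *strides) {
  for (int64_t sourcePos = 0; sourcePos < rank; ++sourcePos) {
    // perm[sourcePos] is the position of the AffineDimExpr in result
    // `sourcePos` of the permutation map.
    int64_t targetPos = perm[sourcePos];
    sizes[targetPos] = baseSizes[sourcePos];
    strides[targetPos] = baseStrides[sourcePos];
  }
}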
+class TransposeOpConversion : public LLVMOpLowering { +public: + explicit TransposeOpConversion(MLIRContext *context, + LLVMTypeConverter &lowering_) + : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + // Initialize the common boilerplate and alloca at the top of the FuncOp. + edsc::ScopedContext context(rewriter, op->getLoc()); + TransposeOpOperandAdaptor adaptor(operands); + BaseViewConversionHelper baseDesc(adaptor.view()); + + auto transposeOp = cast<TransposeOp>(op); + // No permutation, early exit. + if (transposeOp.permutation().isIdentity()) + return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); + + BaseViewConversionHelper desc( + lowering.convertType(transposeOp.getViewType())); + + // Copy the base and aligned pointers from the old descriptor to the new + // one. + desc.setAllocatedPtr(baseDesc.allocatedPtr()); + desc.setAlignedPtr(baseDesc.alignedPtr()); + + // Copy the offset pointer from the old descriptor to the new one. + desc.setOffset(baseDesc.offset()); + + // Iterate over the dimensions and apply size/stride permutation. + for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { + int sourcePos = en.index(); + int targetPos = en.value().cast<AffineDimExpr>().getPosition(); + desc.setSize(targetPos, baseDesc.size(sourcePos)); + desc.setStride(targetPos, baseDesc.stride(sourcePos)); + } + + rewriter.replaceOp(op, {desc}); + return matchSuccess(); + } +}; + +// YieldOp produces and LLVM::ReturnOp. +class YieldOpConversion : public LLVMOpLowering { +public: + explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) + : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands); + return matchSuccess(); + } +}; + +template <typename LinalgOp> +static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) { + return SmallVector<Type, 4>{op->getOperandTypes()}; +} + +template <> +SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) { + auto ctx = op->getContext(); + auto indexedGenericOp = cast<IndexedGenericOp>(op); + auto numLoops = indexedGenericOp.getNumLoops(); + + SmallVector<Type, 4> result; + result.reserve(numLoops + op->getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + result.push_back(IndexType::get(ctx)); + } + for (auto type : op->getOperandTypes()) { + result.push_back(type); + } + return result; +} + +// Get a SymbolRefAttr containing the library function name for the LinalgOp. +// If the library function does not exist, insert a declaration. +template <typename LinalgOp> +static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, + PatternRewriter &rewriter) { + auto linalgOp = cast<LinalgOp>(op); + auto fnName = linalgOp.getLibraryCallName(); + if (fnName.empty()) { + op->emitWarning("No library call defined for: ") << *op; + return {}; + } + + // fnName is a dynamic std::String, unique it via a SymbolRefAttr. 
+ FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); + auto module = op->getParentOfType<ModuleOp>(); + if (module.lookupSymbol(fnName)) { + return fnNameAttr; + } + + SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op)); + assert(op->getNumResults() == 0 && + "Library call for linalg operation can be generated only for ops that " + "have void return types"); + auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext()); + + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType, + ArrayRef<NamedAttribute>{}); + return fnNameAttr; +} + +Type LinalgTypeConverter::convertType(Type t) { + if (auto result = LLVMTypeConverter::convertType(t)) + return result; + return convertLinalgType(t, *this); +} + +// LinalgOpConversion<LinalgOp> creates a new call to the +// `LinalgOp::getLibraryCallName()` function. +// The implementation of the function can be either in the same module or in an +// externally linked library. +template <typename LinalgOp> +class LinalgOpConversion : public OpRewritePattern<LinalgOp> { +public: + using OpRewritePattern<LinalgOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { + auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter); + if (!libraryCallName) + return this->matchFailure(); + + rewriter.replaceOpWithNewOp<mlir::CallOp>( + op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands()); + return this->matchSuccess(); + } +}; + +/// Conversion pattern specialization for CopyOp. This kicks in when both input +/// and output permutations are left unspecified or are the identity. +template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> { +public: + using OpRewritePattern<CopyOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override { + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + return matchFailure(); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + return matchFailure(); + + auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter); + if (!libraryCallName) + return matchFailure(); + + rewriter.replaceOpWithNewOp<mlir::CallOp>( + op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands()); + return matchSuccess(); + } +}; + +/// Conversion pattern specialization for IndexedGenericOp. +template <> +class LinalgOpConversion<IndexedGenericOp> + : public OpRewritePattern<IndexedGenericOp> { +public: + using OpRewritePattern<IndexedGenericOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override { + auto libraryCallName = + getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter); + if (!libraryCallName) + return this->matchFailure(); + + // TODO(pifon, ntv): Use induction variables values instead of zeros, when + // IndexedGenericOp is tiled. 
+ auto zero = rewriter.create<mlir::ConstantOp>( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto indexedGenericOp = cast<IndexedGenericOp>(op); + auto numLoops = indexedGenericOp.getNumLoops(); + SmallVector<Value, 4> operands; + operands.reserve(numLoops + op.getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + operands.push_back(zero); + } + for (auto operand : op.getOperands()) { + operands.push_back(operand); + } + rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(), + ArrayRef<Type>{}, operands); + return this->matchSuccess(); + } +}; + +/// A non-conversion rewrite pattern kicks in to convert CopyOp with +/// permutations into a sequence of TransposeOp and permutation-free CopyOp. +/// This interplays together with TransposeOpConversion and +/// LinalgConversion<CopyOp> to create a path to the LLVM dialect. +class CopyTransposeConversion : public OpRewritePattern<CopyOp> { +public: + using OpRewritePattern<CopyOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override { + Value in = op.input(), out = op.output(); + + // If either inputPerm or outputPerm are non-identities, insert transposes. + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in, + AffineMapAttr::get(*inputPerm)); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + out = rewriter.create<linalg::TransposeOp>( + op.getLoc(), out, AffineMapAttr::get(*outputPerm)); + + // If nothing was transposed, fail and let the conversion kick in. + if (in == op.input() && out == op.output()) + return matchFailure(); + + rewriter.replaceOpWithNewOp<CopyOp>(op, in, out); + return matchSuccess(); + } +}; + +/// Populate the given list with patterns that convert from Linalg to Standard. +static void +populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant + // attribute values such as kernel striding and dilation. + patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>, + LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>, + LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>, + LinalgOpConversion<IndexedGenericOp>, + LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>( + ctx); +} + +/// Populate the given list with patterns that convert from Linalg to LLVM. +void mlir::populateLinalgToLLVMConversionPatterns( + LinalgTypeConverter &converter, OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion, + YieldOpConversion>(ctx, converter); +} + +namespace { +struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> { + void runOnModule() override; +}; +} // namespace + +void ConvertLinalgToLLVMPass::runOnModule() { + auto module = getModule(); + + // Convert to the LLVM IR dialect using the converter defined above. 
+ OwningRewritePatternList patterns; + LinalgTypeConverter converter(&getContext()); + populateAffineToStdConversionPatterns(patterns, &getContext()); + populateLoopToStdConversionPatterns(patterns, &getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + populateVectorToLLVMConversionPatterns(converter, patterns); + populateLinalgToStandardConversionPatterns(patterns, &getContext()); + populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); + if (failed(applyFullConversion(module, target, patterns, &converter))) + signalPassFailure(); +} + +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::linalg::createConvertLinalgToLLVMPass() { + return std::make_unique<ConvertLinalgToLLVMPass>(); +} + +static PassRegistration<ConvertLinalgToLLVMPass> pass( + "convert-linalg-to-llvm", + "Convert the operations from the linalg dialect into the LLVM dialect"); diff --git a/mlir/lib/Conversion/LoopToStandard/CMakeLists.txt b/mlir/lib/Conversion/LoopToStandard/CMakeLists.txt new file mode 100644 index 00000000000..8f05dbd0b63 --- /dev/null +++ b/mlir/lib/Conversion/LoopToStandard/CMakeLists.txt @@ -0,0 +1,22 @@ +add_llvm_library(MLIRLoopToStandard + ConvertLoopToStandard.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LoopToStandard +) +add_dependencies( + MLIRLoopToStandard + + MLIRLoopOps + MLIRTransforms + LLVMCore + LLVMSupport +) +target_link_libraries( + MLIRLoopToStandard + + MLIRLoopOps + MLIRTransforms + LLVMCore + LLVMSupport +) diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp new file mode 100644 index 00000000000..b257e9b482b --- /dev/null +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -0,0 +1,269 @@ +//===- ConvertLoopToStandard.cpp - ControlFlow to CFG conversion ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert loop.for, loop.if and loop.terminator +// ops into standard CFG ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/Functional.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" + +using namespace mlir; +using namespace mlir::loop; + +namespace { + +struct LoopToStandardPass : public OperationPass<LoopToStandardPass> { + void runOnOperation() override; +}; + +// Create a CFG subgraph for the loop around its body blocks (if the body +// contained other loops, they have been already lowered to a flow of blocks). 
+// Maintain the invariants that a CFG subgraph created for any loop has a single +// entry and a single exit, and that the entry/exit blocks are respectively +// first/last blocks in the parent region. The original loop operation is +// replaced by the initialization operations that set up the initial value of +// the loop induction variable (%iv) and computes the loop bounds that are loop- +// invariant for affine loops. The operations following the original loop.for +// are split out into a separate continuation (exit) block. A condition block is +// created before the continuation block. It checks the exit condition of the +// loop and branches either to the continuation block, or to the first block of +// the body. Induction variable modification is appended to the last block of +// the body (which is the exit block from the body subgraph thanks to the +// invariant we maintain) along with a branch that loops back to the condition +// block. +// +// +---------------------------------+ +// | <code before the ForOp> | +// | <compute initial %iv value> | +// | br cond(%iv) | +// +---------------------------------+ +// | +// -------| | +// | v v +// | +--------------------------------+ +// | | cond(%iv): | +// | | <compare %iv to upper bound> | +// | | cond_br %r, body, end | +// | +--------------------------------+ +// | | | +// | | -------------| +// | v | +// | +--------------------------------+ | +// | | body-first: | | +// | | <body contents> | | +// | +--------------------------------+ | +// | | | +// | ... | +// | | | +// | +--------------------------------+ | +// | | body-last: | | +// | | <body contents> | | +// | | %new_iv =<add step to %iv> | | +// | | br cond(%new_iv) | | +// | +--------------------------------+ | +// | | | +// |----------- |-------------------- +// v +// +--------------------------------+ +// | end: | +// | <code after the ForOp> | +// +--------------------------------+ +// +struct ForLowering : public OpRewritePattern<ForOp> { + using OpRewritePattern<ForOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const override; +}; + +// Create a CFG subgraph for the loop.if operation (including its "then" and +// optional "else" operation blocks). We maintain the invariants that the +// subgraph has a single entry and a single exit point, and that the entry/exit +// blocks are respectively the first/last block of the enclosing region. The +// operations following the loop.if are split into a continuation (subgraph +// exit) block. The condition is lowered to a chain of blocks that implement the +// short-circuit scheme. Condition blocks are created by splitting out an empty +// block from the block that contains the loop.if operation. They +// conditionally branch to either the first block of the "then" region, or to +// the first block of the "else" region. If the latter is absent, they branch +// to the continuation block instead. The last blocks of "then" and "else" +// regions (which are known to be exit blocks thanks to the invariant we +// maintain). 
+// +// +--------------------------------+ +// | <code before the IfOp> | +// | cond_br %cond, %then, %else | +// +--------------------------------+ +// | | +// | --------------| +// v | +// +--------------------------------+ | +// | then: | | +// | <then contents> | | +// | br continue | | +// +--------------------------------+ | +// | | +// |---------- |------------- +// | V +// | +--------------------------------+ +// | | else: | +// | | <else contents> | +// | | br continue | +// | +--------------------------------+ +// | | +// ------| | +// v v +// +--------------------------------+ +// | continue: | +// | <code after the IfOp> | +// +--------------------------------+ +// +struct IfLowering : public OpRewritePattern<IfOp> { + using OpRewritePattern<IfOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IfOp ifOp, + PatternRewriter &rewriter) const override; +}; + +struct TerminatorLowering : public OpRewritePattern<TerminatorOp> { + using OpRewritePattern<TerminatorOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TerminatorOp op, + PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return matchSuccess(); + } +}; +} // namespace + +PatternMatchResult +ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { + Location loc = forOp.getLoc(); + + // Start by splitting the block containing the 'loop.for' into two parts. + // The part before will get the init code, the part after will be the end + // point. + auto *initBlock = rewriter.getInsertionBlock(); + auto initPosition = rewriter.getInsertionPoint(); + auto *endBlock = rewriter.splitBlock(initBlock, initPosition); + + // Use the first block of the loop body as the condition block since it is + // the block that has the induction variable as its argument. Split out + // all operations from the first block into a new block. Move all body + // blocks from the loop body region to the region containing the loop. + auto *conditionBlock = &forOp.region().front(); + auto *firstBodyBlock = + rewriter.splitBlock(conditionBlock, conditionBlock->begin()); + auto *lastBodyBlock = &forOp.region().back(); + rewriter.inlineRegionBefore(forOp.region(), endBlock); + auto iv = conditionBlock->getArgument(0); + + // Append the induction variable stepping logic to the last body block and + // branch back to the condition block. Construct an expression f : + // (x -> x+step) and apply this expression to the induction variable. + rewriter.setInsertionPointToEnd(lastBodyBlock); + auto step = forOp.step(); + auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult(); + if (!stepped) + return matchFailure(); + rewriter.create<BranchOp>(loc, conditionBlock, stepped); + + // Compute loop bounds before branching to the condition. + rewriter.setInsertionPointToEnd(initBlock); + Value lowerBound = forOp.lowerBound(); + Value upperBound = forOp.upperBound(); + if (!lowerBound || !upperBound) + return matchFailure(); + rewriter.create<BranchOp>(loc, conditionBlock, lowerBound); + + // With the body block done, we can fill in the condition block. + rewriter.setInsertionPointToEnd(conditionBlock); + auto comparison = + rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound); + + rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock, + ArrayRef<Value>(), endBlock, ArrayRef<Value>()); + // Ok, we're done! 
+ rewriter.eraseOp(forOp); + return matchSuccess(); +} + +PatternMatchResult +IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { + auto loc = ifOp.getLoc(); + + // Start by splitting the block containing the 'loop.if' into two parts. + // The part before will contain the condition, the part after will be the + // continuation point. + auto *condBlock = rewriter.getInsertionBlock(); + auto opPosition = rewriter.getInsertionPoint(); + auto *continueBlock = rewriter.splitBlock(condBlock, opPosition); + + // Move blocks from the "then" region to the region containing 'loop.if', + // place it before the continuation block, and branch to it. + auto &thenRegion = ifOp.thenRegion(); + auto *thenBlock = &thenRegion.front(); + rewriter.setInsertionPointToEnd(&thenRegion.back()); + rewriter.create<BranchOp>(loc, continueBlock); + rewriter.inlineRegionBefore(thenRegion, continueBlock); + + // Move blocks from the "else" region (if present) to the region containing + // 'loop.if', place it before the continuation block and branch to it. It + // will be placed after the "then" regions. + auto *elseBlock = continueBlock; + auto &elseRegion = ifOp.elseRegion(); + if (!elseRegion.empty()) { + elseBlock = &elseRegion.front(); + rewriter.setInsertionPointToEnd(&elseRegion.back()); + rewriter.create<BranchOp>(loc, continueBlock); + rewriter.inlineRegionBefore(elseRegion, continueBlock); + } + + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock, + /*trueArgs=*/ArrayRef<Value>(), elseBlock, + /*falseArgs=*/ArrayRef<Value>()); + + // Ok, we're done! + rewriter.eraseOp(ifOp); + return matchSuccess(); +} + +void mlir::populateLoopToStdConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx); +} + +void LoopToStandardPass::runOnOperation() { + OwningRewritePatternList patterns; + populateLoopToStdConversionPatterns(patterns, &getContext()); + ConversionTarget target(getContext()); + target.addLegalDialect<StandardOpsDialect>(); + if (failed(applyPartialConversion(getOperation(), target, patterns))) + signalPassFailure(); +} + +std::unique_ptr<Pass> mlir::createLowerToCFGPass() { + return std::make_unique<LoopToStandardPass>(); +} + +static PassRegistration<LoopToStandardPass> + pass("convert-loop-to-std", "Convert Loop dialect to Standard dialect, " + "replacing structured control flow with a CFG"); diff --git a/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt b/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt new file mode 100644 index 00000000000..2dacc800cb2 --- /dev/null +++ b/mlir/lib/Conversion/LoopsToGPU/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LIBS + MLIRAffineOps + MLIRGPU + MLIRIR + MLIRLinalg + MLIRPass + MLIRStandardOps + MLIRSupport + MLIRTransforms + LLVMSupport +) + +add_llvm_library(MLIRLoopsToGPU + LoopsToGPU.cpp + LoopsToGPUPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LoopsToGPU +) +add_dependencies(MLIRLoopsToGPU ${LIBS}) +target_link_libraries(MLIRLoopsToGPU ${LIBS}) diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp new file mode 100644 index 00000000000..e500d10983c --- /dev/null +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -0,0 +1,528 @@ +//===- LoopsToGPU.cpp - Convert an affine loop nest to a GPU kernel -------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This implements a straightforward conversion of an loop nest into a GPU +// kernel. The caller is expected to guarantee that the conversion is correct +// or to further transform the kernel to ensure correctness. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "loops-to-gpu" + +using namespace mlir; +using namespace mlir::loop; + +using llvm::seq; + +// Extract an indexed value from KernelDim3. +static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { + switch (pos) { + case 0: + return dim3.x; + case 1: + return dim3.y; + case 2: + return dim3.z; + default: + llvm_unreachable("dim3 position out of bounds"); + } + return nullptr; +} + +// Get the lower bound-related operands of a loop operation. +static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) { + return forOp.getLowerBoundOperands(); +} +static SmallVector<Value, 1> getLowerBoundOperands(ForOp forOp) { + SmallVector<Value, 1> bounds(1, forOp.lowerBound()); + return bounds; +} + +// Get the upper bound-related operands of a loop operation. +static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { + return forOp.getUpperBoundOperands(); +} +static SmallVector<Value, 1> getUpperBoundOperands(ForOp forOp) { + SmallVector<Value, 1> bounds(1, forOp.upperBound()); + return bounds; +} + +// Get a Value that corresponds to the loop step. If the step is an attribute, +// materialize a corresponding constant using builder. +static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { + return builder.create<ConstantIndexOp>(forOp.getLoc(), forOp.getStep()); +} +static Value getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); } + +// Get a Value for the loop lower bound. If the value requires computation, +// materialize the instructions using builder. +static Value getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) { + return lowerAffineLowerBound(forOp, builder); +} +static Value getOrEmitLowerBound(ForOp forOp, OpBuilder &) { + return forOp.lowerBound(); +} + +// Get a Value for the loop upper bound. If the value requires computation, +// materialize the instructions using builder. +static Value getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) { + return lowerAffineUpperBound(forOp, builder); +} +static Value getOrEmitUpperBound(ForOp forOp, OpBuilder &) { + return forOp.upperBound(); +} + +// Check the structure of the loop nest: +// - there are enough loops to map to numDims; +// - the loops are perfectly nested; +// - the loop bounds can be computed above the outermost loop. +// This roughly corresponds to the "matcher" part of the pattern-based +// rewriting infrastructure. 
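+//
+// For instance (an illustrative sketch, not an existing test case), a nest
+// such as
+//
+//   loop.for %i = %c0 to %n step %c1 {
+//     loop.for %j = %c0 to %m step %c1 {
+//       "compute"(%i, %j) : (index, index) -> ()
+//     }
+//   }
+//
+// is mappable to two dimensions because %n, %m and the steps are defined
+// above the outermost loop, whereas a nest whose inner bounds are computed
+// inside the outer loop body would be rejected.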
+template <typename OpTy>
+LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) {
+  Region &limit = forOp.region();
+  for (unsigned i = 0, e = numDims; i < e; ++i) {
+    Operation *nested = &forOp.getBody()->front();
+    if (!areValuesDefinedAbove(getLowerBoundOperands(forOp), limit) ||
+        !areValuesDefinedAbove(getUpperBoundOperands(forOp), limit))
+      return forOp.emitError(
+          "loops with bounds depending on other mapped loops "
+          "are not supported");
+
+    // The innermost loop can have an arbitrary body, skip the perfect nesting
+    // check for it.
+    if (i == e - 1)
+      break;
+
+    auto begin = forOp.getBody()->begin(), end = forOp.getBody()->end();
+    if (forOp.getBody()->empty() || std::next(begin, 2) != end)
+      return forOp.emitError("expected perfectly nested loops in the body");
+
+    if (!(forOp = dyn_cast<OpTy>(nested)))
+      return nested->emitError("expected a nested loop");
+  }
+  return success();
+}
+
+template <typename OpTy>
+LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
+                                    unsigned numThreadDims) {
+  if (numBlockDims < 1 || numThreadDims < 1) {
+    LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+    return success();
+  }
+
+  OpBuilder builder(forOp.getOperation());
+  if (numBlockDims > 3) {
+    return forOp.emitError("cannot map to more than 3 block dimensions");
+  }
+  if (numThreadDims > 3) {
+    return forOp.emitError("cannot map to more than 3 thread dimensions");
+  }
+  return checkLoopNestMappableImpl(forOp, numBlockDims + numThreadDims);
+}
+
+template <typename OpTy>
+LogicalResult checkLoopOpMappable(OpTy forOp, unsigned numBlockDims,
+                                  unsigned numThreadDims) {
+  if (numBlockDims < 1 || numThreadDims < 1) {
+    LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+    return success();
+  }
+
+  if (numBlockDims > 3) {
+    return forOp.emitError("cannot map to more than 3 block dimensions");
+  }
+  if (numThreadDims > 3) {
+    return forOp.emitError("cannot map to more than 3 thread dimensions");
+  }
+  if (numBlockDims != numThreadDims) {
+    // TODO(ravishankarm) : This can probably be relaxed by having a one-trip
+    // loop for the missing dimension, but there is no reason to handle this
+    // case for now.
+    return forOp.emitError(
+        "mismatch in block dimensions and thread dimensions");
+  }
+
+  // Check that the forOp contains perfectly nested loops for numBlockDims.
+  if (failed(checkLoopNestMappableImpl(forOp, numBlockDims))) {
+    return failure();
+  }
+
+  // Get to the innermost loop.
+  for (auto i : seq<unsigned>(0, numBlockDims - 1)) {
+    forOp = cast<OpTy>(&forOp.getBody()->front());
+    (void)i;
+  }
+
+  // The forOp now points to the body of the innermost loop mapped to blocks.
+  for (Operation &op : *forOp.getBody()) {
+    // If the operation is a loop, check that it is mappable to workItems.
+    if (auto innerLoop = dyn_cast<OpTy>(&op)) {
+      if (failed(checkLoopNestMappableImpl(innerLoop, numThreadDims))) {
+        return failure();
+      }
+      continue;
+    }
+    // TODO(ravishankarm) : If it is not a loop op, it is assumed that the
+    // statement is executed by all threads. It might be a collective
+    // operation, or some side-effect-free instruction. Have to decide on
+    // "allowable" statements and check for those here.
+  }
+  return success();
+}
+
+namespace {
+// Helper structure that holds common state of the loop to GPU kernel
+// conversion.
+struct LoopToGpuConverter {
+  template <typename OpTy>
+  Optional<OpTy> collectBounds(OpTy forOp, unsigned numLoops);
+
+  template <typename OpTy>
+  void createLaunch(OpTy rootForOp, OpTy innermostForOp, unsigned numBlockDims,
+                    unsigned numThreadDims);
+
+  // Ranges of the loops mapped to blocks or threads.
+  SmallVector<Value, 6> dims;
+  // Lower bounds of the loops mapped to blocks or threads.
+  SmallVector<Value, 6> lbs;
+  // Induction variables of the loops mapped to blocks or threads.
+  SmallVector<Value, 6> ivs;
+  // Steps of the loops mapped to blocks or threads.
+  SmallVector<Value, 6> steps;
+};
+} // namespace
+
+// Return true if the value is obviously a constant "one".
+static bool isConstantOne(Value value) {
+  if (auto def = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
+    return def.getValue() == 1;
+  return false;
+}
+
+// Collect ranges, bounds, steps and induction variables in preparation for
+// mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel.
+// This may fail if the IR for computing loop bounds cannot be constructed, for
+// example if an affine loop uses semi-affine maps. Return the last loop to be
+// mapped on success, llvm::None on failure.
+template <typename OpTy>
+Optional<OpTy> LoopToGpuConverter::collectBounds(OpTy forOp,
+                                                 unsigned numLoops) {
+  OpBuilder builder(forOp.getOperation());
+  dims.reserve(numLoops);
+  lbs.reserve(numLoops);
+  ivs.reserve(numLoops);
+  steps.reserve(numLoops);
+  OpTy currentLoop = forOp;
+  for (unsigned i = 0; i < numLoops; ++i) {
+    Value lowerBound = getOrEmitLowerBound(currentLoop, builder);
+    Value upperBound = getOrEmitUpperBound(currentLoop, builder);
+    if (!lowerBound || !upperBound) {
+      return llvm::None;
+    }
+
+    Value range =
+        builder.create<SubIOp>(currentLoop.getLoc(), upperBound, lowerBound);
+    Value step = getOrCreateStep(currentLoop, builder);
+    if (!isConstantOne(step))
+      range = builder.create<SignedDivIOp>(currentLoop.getLoc(), range, step);
+    dims.push_back(range);
+
+    lbs.push_back(lowerBound);
+    ivs.push_back(currentLoop.getInductionVar());
+    steps.push_back(step);
+
+    if (i != numLoops - 1)
+      currentLoop = cast<OpTy>(&currentLoop.getBody()->front());
+  }
+  return currentLoop;
+}
+
+/// Given `nDims` perfectly nested loops rooted at `rootForOp`, convert them to
+/// be partitioned across workgroups or workitems. The values for the
+/// workgroup/workitem id along each dimension are passed in with `ids`. The
+/// number of workgroups/workitems along each dimension are passed in with
+/// `nids`. The innermost loop is mapped to the x-dimension, the next innermost
+/// loop to the y-dimension, and the next one to the z-dimension.
+template <typename OpTy>
+OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value> ids,
+                          ArrayRef<Value> nids) {
+  auto nDims = ids.size();
+  assert(nDims == nids.size());
+  for (auto dim : llvm::seq<unsigned>(0, nDims)) {
+    // TODO(ravishankarm): Don't always need to generate a loop here. If nids >=
+    // number of iterations of the original loop, this becomes an if
+    // condition. Though that does rely on how the workgroup/workitem sizes are
+    // specified to begin with.
+    mapLoopToProcessorIds(rootForOp, ids[dim], nids[dim]);
+    if (dim != nDims - 1) {
+      rootForOp = cast<OpTy>(rootForOp.getBody()->front());
+    }
+  }
+  return rootForOp;
+}
+
+/// Utility method to convert the gpu::KernelDim3 object for representing id of
+/// each workgroup/workitem and number of workgroup/workitems along a dimension
+/// of the launch into a container.
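+/// For example (illustrative), with nDims == 2 both output containers hold
+/// {y, x}: the outermost mapped loop is associated with the y dimension and
+/// the innermost one with the x dimension.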
+void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids, + unsigned nDims, SmallVectorImpl<Value> &ids, + SmallVectorImpl<Value> &nids) { + assert(nDims <= 3 && "invalid number of launch dimensions"); + SmallVector<Value, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x}; + SmallVector<Value, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x}; + ids.clear(); + ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end()); + nids.clear(); + nids.append(std::next(allNids.begin(), allNids.size() - nDims), + allNids.end()); +} + +/// Generate the body of the launch operation. +template <typename OpTy> +LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp, + gpu::LaunchOp launchOp, unsigned numBlockDims, + unsigned numThreadDims) { + OpBuilder::InsertionGuard bodyInsertionGuard(builder); + builder.setInsertionPointToEnd(&launchOp.body().front()); + auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc()); + + rootForOp.getOperation()->moveBefore(returnOp); + SmallVector<Value, 3> workgroupID, numWorkGroups; + packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims, + workgroupID, numWorkGroups); + + // Partition the loop for mapping to workgroups. + auto loopOp = createGPULaunchLoops(rootForOp, workgroupID, numWorkGroups); + + // Iterate over the body of the loopOp and get the loops to partition for + // thread blocks. + SmallVector<OpTy, 1> threadRootForOps; + for (Operation &op : *loopOp.getBody()) { + if (auto threadRootForOp = dyn_cast<OpTy>(&op)) { + threadRootForOps.push_back(threadRootForOp); + } + } + + SmallVector<Value, 3> workItemID, workGroupSize; + packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(), + numThreadDims, workItemID, workGroupSize); + for (auto &loopOp : threadRootForOps) { + builder.setInsertionPoint(loopOp); + createGPULaunchLoops(loopOp, workItemID, workGroupSize); + } + return success(); +} + +// Convert the computation rooted at the `rootForOp`, into a GPU kernel with the +// given workgroup size and number of workgroups. +template <typename OpTy> +LogicalResult createLaunchFromOp(OpTy rootForOp, ArrayRef<Value> numWorkGroups, + ArrayRef<Value> workGroupSizes) { + OpBuilder builder(rootForOp.getOperation()); + if (numWorkGroups.size() > 3) { + return rootForOp.emitError("invalid ") + << numWorkGroups.size() << "-D workgroup specification"; + } + auto loc = rootForOp.getLoc(); + Value one = builder.create<ConstantOp>( + loc, builder.getIntegerAttr(builder.getIndexType(), 1)); + SmallVector<Value, 3> numWorkGroups3D(3, one), workGroupSize3D(3, one); + for (auto numWorkGroup : enumerate(numWorkGroups)) { + numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value(); + } + for (auto workGroupSize : enumerate(workGroupSizes)) { + workGroupSize3D[workGroupSize.index()] = workGroupSize.value(); + } + + // Get the values used within the region of the rootForOp but defined above + // it. + llvm::SetVector<Value> valuesToForwardSet; + getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(), + valuesToForwardSet); + // Also add the values used for the lb, ub, and step of the rootForOp. 
+ valuesToForwardSet.insert(rootForOp.getOperands().begin(), + rootForOp.getOperands().end()); + auto valuesToForward = valuesToForwardSet.takeVector(); + auto launchOp = builder.create<gpu::LaunchOp>( + rootForOp.getLoc(), numWorkGroups3D[0], numWorkGroups3D[1], + numWorkGroups3D[2], workGroupSize3D[0], workGroupSize3D[1], + workGroupSize3D[2], valuesToForward); + if (failed(createLaunchBody(builder, rootForOp, launchOp, + numWorkGroups.size(), workGroupSizes.size()))) { + return failure(); + } + + // Replace values that are used within the region of the launchOp but are + // defined outside. They all are replaced with kernel arguments. + for (const auto &pair : + llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { + Value from = std::get<0>(pair); + Value to = std::get<1>(pair); + replaceAllUsesInRegionWith(from, to, launchOp.body()); + } + return success(); +} + +// Replace the rooted at "rootForOp" with a GPU launch operation. This expects +// "innermostForOp" to point to the last loop to be transformed to the kernel, +// and to have (numBlockDims + numThreadDims) perfectly nested loops between +// "rootForOp" and "innermostForOp". +// TODO(ravishankarm) : This method can be modified to use the +// createLaunchFromOp method, since that is a strict generalization of this +// method. +template <typename OpTy> +void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp, + unsigned numBlockDims, + unsigned numThreadDims) { + OpBuilder builder(rootForOp.getOperation()); + // Prepare the grid and block sizes for the launch operation. If there is + // no loop mapped to a specific dimension, use constant "1" as its size. + Value constOne = (numBlockDims < 3 || numThreadDims < 3) + ? builder.create<ConstantIndexOp>(rootForOp.getLoc(), 1) + : nullptr; + Value gridSizeX = dims[0]; + Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; + Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; + Value blockSizeX = dims[numBlockDims]; + Value blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne; + Value blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne; + + // Create a launch op and move the body region of the innermost loop to the + // launch op. Pass the values defined outside the outermost loop and used + // inside the innermost loop and loop lower bounds as kernel data arguments. + // Still assuming perfect nesting so there are no values other than induction + // variables that are defined in one loop and used in deeper loops. + llvm::SetVector<Value> valuesToForwardSet; + getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(), + valuesToForwardSet); + auto valuesToForward = valuesToForwardSet.takeVector(); + auto originallyForwardedValues = valuesToForward.size(); + valuesToForward.insert(valuesToForward.end(), lbs.begin(), lbs.end()); + valuesToForward.insert(valuesToForward.end(), steps.begin(), steps.end()); + auto launchOp = builder.create<gpu::LaunchOp>( + rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, + blockSizeY, blockSizeZ, valuesToForward); + valuesToForward.resize(originallyForwardedValues); + + // Replace the loop terminator (loops contain only a single block) with the + // gpu return and move the operations from the loop body block to the gpu + // launch body block. Do not move the entire block because of the difference + // in block arguments. 
+ Operation &terminator = innermostForOp.getBody()->back(); + Location terminatorLoc = terminator.getLoc(); + terminator.erase(); + builder.setInsertionPointToEnd(innermostForOp.getBody()); + builder.create<gpu::ReturnOp>(terminatorLoc); + launchOp.body().front().getOperations().splice( + launchOp.body().front().begin(), + innermostForOp.getBody()->getOperations()); + + // Remap the loop iterators to use block/thread identifiers instead. Loops + // may iterate from LB with step S whereas GPU thread/block ids always iterate + // from 0 to N with step 1. Therefore, loop induction variables are replaced + // with (gpu-thread/block-id * S) + LB. + builder.setInsertionPointToStart(&launchOp.body().front()); + auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(), + originallyForwardedValues); + auto stepArgumentIt = std::next(lbArgumentIt, lbs.size()); + for (auto en : llvm::enumerate(ivs)) { + Value id = + en.index() < numBlockDims + ? getDim3Value(launchOp.getBlockIds(), en.index()) + : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); + Value step = steps[en.index()]; + if (!isConstantOne(step)) + id = builder.create<MulIOp>(rootForOp.getLoc(), step, id); + + Value ivReplacement = + builder.create<AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id); + en.value()->replaceAllUsesWith(ivReplacement); + replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt, + launchOp.body()); + std::advance(lbArgumentIt, 1); + std::advance(stepArgumentIt, 1); + } + + // Remap the values defined outside the body to use kernel arguments instead. + // The list of kernel arguments also contains the lower bounds for loops at + // trailing positions, make sure we don't touch those. + for (const auto &pair : + llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) { + Value from = std::get<0>(pair); + Value to = std::get<1>(pair); + replaceAllUsesInRegionWith(from, to, launchOp.body()); + } + + // We are done and can erase the original outermost loop. + rootForOp.erase(); +} + +// Generic loop to GPU kernel conversion function. +template <typename OpTy> +static LogicalResult convertLoopNestToGPULaunch(OpTy forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + if (failed(checkLoopNestMappable(forOp, numBlockDims, numThreadDims))) + return failure(); + + LoopToGpuConverter converter; + auto maybeInnerLoop = + converter.collectBounds(forOp, numBlockDims + numThreadDims); + if (!maybeInnerLoop) + return failure(); + converter.createLaunch(forOp, *maybeInnerLoop, numBlockDims, numThreadDims); + + return success(); +} + +// Generic loop to GPU kernel conversion function when loop is imperfectly +// nested. 
The workgroup size and num workgroups is provided as input +template <typename OpTy> +static LogicalResult convertLoopToGPULaunch(OpTy forOp, + ArrayRef<Value> numWorkGroups, + ArrayRef<Value> workGroupSize) { + if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(), + workGroupSize.size()))) { + return failure(); + } + return createLaunchFromOp(forOp, numWorkGroups, workGroupSize); +} + +LogicalResult mlir::convertAffineLoopNestToGPULaunch(AffineForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); +} + +LogicalResult mlir::convertLoopNestToGPULaunch(ForOp forOp, + unsigned numBlockDims, + unsigned numThreadDims) { + return ::convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims); +} + +LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp, + ArrayRef<Value> numWorkGroups, + ArrayRef<Value> workGroupSizes) { + return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes); +} diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp new file mode 100644 index 00000000000..c3bbf274818 --- /dev/null +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -0,0 +1,147 @@ +//===- LoopsToGPUPass.cpp - Convert a loop nest to a GPU kernel -----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" +#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/CommandLine.h" + +#define PASS_NAME "convert-loops-to-gpu" +#define LOOPOP_TO_GPU_PASS_NAME "convert-loop-op-to-gpu" + +using namespace mlir; +using namespace mlir::loop; + +static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options"); +static llvm::cl::opt<unsigned> + clNumBlockDims("gpu-block-dims", + llvm::cl::desc("Number of GPU block dimensions for mapping"), + llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u)); +static llvm::cl::opt<unsigned> clNumThreadDims( + "gpu-thread-dims", + llvm::cl::desc("Number of GPU thread dimensions for mapping"), + llvm::cl::cat(clOptionsCategory), llvm::cl::init(1u)); + +static llvm::cl::OptionCategory clLoopOpToGPUCategory(LOOPOP_TO_GPU_PASS_NAME + " options"); +static llvm::cl::list<unsigned> + clNumWorkGroups("gpu-num-workgroups", + llvm::cl::desc("Num workgroups in the GPU launch"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::cat(clLoopOpToGPUCategory)); +static llvm::cl::list<unsigned> + clWorkGroupSize("gpu-workgroup-size", + llvm::cl::desc("Workgroup Size in the GPU launch"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::cat(clLoopOpToGPUCategory)); + +namespace { +// A pass that traverses top-level loops in the function and converts them to +// GPU launch operations. Nested launches are not allowed, so this does not +// walk the function recursively to avoid considering nested loops. 
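+//
+// A typical (hypothetical) invocation for mapping a 4-deep loop nest onto a
+// 2-D grid of 2-D blocks would be roughly:
+//   mlir-opt -convert-loops-to-gpu -gpu-block-dims=2 -gpu-thread-dims=2 in.mlir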
+struct ForLoopMapper : public FunctionPass<ForLoopMapper> {
+  ForLoopMapper(unsigned numBlockDims, unsigned numThreadDims)
+      : numBlockDims(numBlockDims), numThreadDims(numThreadDims) {}
+
+  void runOnFunction() override {
+    for (Block &block : getFunction())
+      for (Operation &op : llvm::make_early_inc_range(block)) {
+        if (auto forOp = dyn_cast<AffineForOp>(&op)) {
+          if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
+                                                      numThreadDims)))
+            signalPassFailure();
+        } else if (auto forOp = dyn_cast<ForOp>(&op)) {
+          if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims,
+                                                numThreadDims)))
+            signalPassFailure();
+        }
+      }
+  }
+
+  unsigned numBlockDims;
+  unsigned numThreadDims;
+};
+
+// A pass that traverses top-level loops in the function and converts them to
+// GPU launch operations. The top-level loops themselves do not have to be
+// perfectly nested. The only requirement is that there be as many perfectly
+// nested loops as the size of `numWorkGroups`. Within these, any loop nest has
+// to be perfectly nested up to a depth equal to the size of `workGroupSize`.
+struct ImperfectlyNestedForLoopMapper
+    : public FunctionPass<ImperfectlyNestedForLoopMapper> {
+  ImperfectlyNestedForLoopMapper(ArrayRef<int64_t> numWorkGroups,
+                                 ArrayRef<int64_t> workGroupSize)
+      : numWorkGroups(numWorkGroups.begin(), numWorkGroups.end()),
+        workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+
+  void runOnFunction() override {
+    // Insert the num work groups and workgroup sizes as constant values. This
+    // pass is only used for testing.
+    FuncOp funcOp = getFunction();
+    OpBuilder builder(funcOp.getOperation()->getRegion(0));
+    SmallVector<Value, 3> numWorkGroupsVal, workGroupSizeVal;
+    for (auto val : numWorkGroups) {
+      auto constOp = builder.create<ConstantOp>(
+          funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
+      numWorkGroupsVal.push_back(constOp);
+    }
+    for (auto val : workGroupSize) {
+      auto constOp = builder.create<ConstantOp>(
+          funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
+      workGroupSizeVal.push_back(constOp);
+    }
+    for (Block &block : getFunction()) {
+      for (Operation &op : llvm::make_early_inc_range(block)) {
+        if (auto forOp = dyn_cast<ForOp>(&op)) {
+          if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
+                                            workGroupSizeVal))) {
+            return signalPassFailure();
+          }
+        }
+      }
+    }
+  }
+  SmallVector<int64_t, 3> numWorkGroups;
+  SmallVector<int64_t, 3> workGroupSize;
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createSimpleLoopsToGPUPass(unsigned numBlockDims,
+                                 unsigned numThreadDims) {
+  return std::make_unique<ForLoopMapper>(numBlockDims, numThreadDims);
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups,
+                          ArrayRef<int64_t> workGroupSize) {
+  return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups,
+                                                          workGroupSize);
+}
+
+static PassRegistration<ForLoopMapper>
+    registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] {
+      return std::make_unique<ForLoopMapper>(clNumBlockDims.getValue(),
+                                             clNumThreadDims.getValue());
+    });
+
+static PassRegistration<ImperfectlyNestedForLoopMapper> loopOpToGPU(
+    LOOPOP_TO_GPU_PASS_NAME, "Convert top-level loop::ForOp to GPU kernels",
+    [] {
+      SmallVector<int64_t, 3> numWorkGroups, workGroupSize;
+      numWorkGroups.assign(clNumWorkGroups.begin(), clNumWorkGroups.end());
+      workGroupSize.assign(clWorkGroupSize.begin(), clWorkGroupSize.end());
+      return
std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups, + workGroupSize); + }); diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt new file mode 100644 index 00000000000..6334c273493 --- /dev/null +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -0,0 +1,24 @@ +add_llvm_library(MLIRStandardToLLVM + ConvertStandardToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/StandardToLLVM +) +add_dependencies( + MLIRStandardToLLVM + + MLIRLoopToStandard + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport +) +target_link_libraries( + MLIRStandardToLLVM + + MLIRLoopToStandard + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport +) diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp new file mode 100644 index 00000000000..0c96cc5e9c7 --- /dev/null +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -0,0 +1,2278 @@ +//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR standard and builtin dialects +// into the LLVM IR dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/ADT/TypeSwitch.h" +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/Functional.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/Utils.h" + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" + +using namespace mlir; + +#define PASS_NAME "convert-std-to-llvm" + +static llvm::cl::OptionCategory + clOptionsCategory("Standard to LLVM lowering options"); + +static llvm::cl::opt<bool> + clUseAlloca(PASS_NAME "-use-alloca", + llvm::cl::desc("Replace emission of malloc/free by alloca"), + llvm::cl::init(false)); + +LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) + : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) { + assert(llvmDialect && "LLVM IR dialect is not registered"); + module = &llvmDialect->getLLVMModule(); +} + +// Get the LLVM context. +llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() { + return module->getContext(); +} + +// Extract an LLVM IR type from the LLVM IR dialect type. 
+LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  auto *mlirContext = type.getContext();
+  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedLLVMType)
+    emitError(UnknownLoc::get(mlirContext),
+              "conversion resulted in a non-LLVM type");
+  return wrappedLLVMType;
+}
+
+LLVM::LLVMType LLVMTypeConverter::getIndexType() {
+  return LLVM::LLVMType::getIntNTy(
+      llvmDialect, module->getDataLayout().getPointerSizeInBits());
+}
+
+Type LLVMTypeConverter::convertIndexType(IndexType type) {
+  return getIndexType();
+}
+
+Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
+  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
+}
+
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
+  switch (type.getKind()) {
+  case mlir::StandardTypes::F32:
+    return LLVM::LLVMType::getFloatTy(llvmDialect);
+  case mlir::StandardTypes::F64:
+    return LLVM::LLVMType::getDoubleTy(llvmDialect);
+  case mlir::StandardTypes::F16:
+    return LLVM::LLVMType::getHalfTy(llvmDialect);
+  case mlir::StandardTypes::BF16: {
+    auto *mlirContext = llvmDialect->getContext();
+    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
+           Type();
+  }
+  default:
+    llvm_unreachable("non-float type in convertFloatType");
+  }
+}
+
+// Except for signatures, MLIR function types are converted into LLVM
+// pointer-to-function types.
+Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
+  SignatureConversion conversion(type.getNumInputs());
+  LLVM::LLVMType converted =
+      convertFunctionSignature(type, /*isVariadic=*/false, conversion);
+  return converted.getPointerTo();
+}
+
+// Function types are converted to LLVM Function types by recursively
+// converting argument and result types. If an MLIR Function has zero results,
+// the LLVM Function has one VoidType result. If an MLIR Function has more than
+// one result, they are packed into an LLVM StructType in their order of
+// appearance.
+LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
+    FunctionType type, bool isVariadic,
+    LLVMTypeConverter::SignatureConversion &result) {
+  // Convert argument types one by one and check for errors.
+  for (auto &en : llvm::enumerate(type.getInputs())) {
+    Type type = en.value();
+    auto converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+    if (!converted)
+      return {};
+    if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
+      converted = converted.getPointerTo();
+    result.addInputs(en.index(), converted);
+  }
+
+  SmallVector<LLVM::LLVMType, 8> argTypes;
+  argTypes.reserve(llvm::size(result.getConvertedTypes()));
+  for (Type type : result.getConvertedTypes())
+    argTypes.push_back(unwrap(type));
+
+  // If the function does not return anything, create the void result type;
+  // if it returns one element, convert it; otherwise pack the result types
+  // into a struct.
+  LLVM::LLVMType resultType =
+      type.getNumResults() == 0
+          ? LLVM::LLVMType::getVoidTy(llvmDialect)
+          : unwrap(packFunctionResults(type.getResults()));
+  if (!resultType)
+    return {};
+  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
+}
+
+// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
+// contains:
+// 1. the pointer to the data buffer, followed by
+// 2. a lowered `index`-type integer containing the distance between the
+// beginning of the buffer and the first element to be accessed through the
+// view, followed by
+// 3.
an array containing as many `index`-type integers as the rank of the +// MemRef: the array represents the size, in number of elements, of the memref +// along the given dimension. For constant MemRef dimensions, the +// corresponding size entry is a constant whose runtime value must match the +// static value, followed by +// 4. a second array containing as many `index`-type integers as the rank of +// the MemRef: the second array represents the "stride" (in tensor abstraction +// sense), i.e. the number of consecutive elements of the underlying buffer. +// TODO(ntv, zinenko): add assertions for the static cases. +// +// template <typename Elem, size_t Rank> +// struct { +// Elem *allocatedPtr; +// Elem *alignedPtr; +// int64_t offset; +// int64_t sizes[Rank]; // omitted when rank == 0 +// int64_t strides[Rank]; // omitted when rank == 0 +// }; +static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; +static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; +static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; +static constexpr unsigned kSizePosInMemRefDescriptor = 3; +static constexpr unsigned kStridePosInMemRefDescriptor = 4; +Type LLVMTypeConverter::convertMemRefType(MemRefType type) { + int64_t offset; + SmallVector<int64_t, 4> strides; + bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); + assert(strideSuccess && + "Non-strided layout maps must have been normalized away"); + (void)strideSuccess; + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); + auto indexTy = getIndexType(); + auto rank = type.getRank(); + if (rank > 0) { + auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank()); + return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy); + } + return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy); +} + +// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which +// contains: +// 1. int64_t rank, the dynamic rank of this MemRef +// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be +// stack allocated (alloca) copy of a MemRef descriptor that got casted to +// be unranked. + +static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; +static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; + +Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { + auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + return LLVM::LLVMType::getStructTy(rankTy, ptrTy); +} + +// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when +// n > 1. +// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and +// `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. +Type LLVMTypeConverter::convertVectorType(VectorType type) { + auto elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + auto vectorType = + LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); + auto shape = type.getShape(); + for (int i = shape.size() - 2; i >= 0; --i) + vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); + return vectorType; +} + +// Dispatch based on the actual type. Return null type on error. 
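+// For instance (illustrative, assuming 64-bit pointers so that `index` lowers
+// to i64), memref<?x42xf32> is dispatched to convertMemRefType and becomes
+// !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">, while vector<4xf32>
+// is dispatched to convertVectorType and becomes !llvm<"<4 x float>">.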
+Type LLVMTypeConverter::convertStandardType(Type t) { + return TypeSwitch<Type, Type>(t) + .Case([&](FloatType type) { return convertFloatType(type); }) + .Case([&](FunctionType type) { return convertFunctionType(type); }) + .Case([&](IndexType type) { return convertIndexType(type); }) + .Case([&](IntegerType type) { return convertIntegerType(type); }) + .Case([&](MemRefType type) { return convertMemRefType(type); }) + .Case([&](UnrankedMemRefType type) { + return convertUnrankedMemRefType(type); + }) + .Case([&](VectorType type) { return convertVectorType(type); }) + .Case([](LLVM::LLVMType type) { return type; }) + .Default([](Type) { return Type(); }); +} + +LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, + LLVMTypeConverter &lowering_, + PatternBenefit benefit) + : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} + +/*============================================================================*/ +/* StructBuilder implementation */ +/*============================================================================*/ +StructBuilder::StructBuilder(Value v) : value(v) { + assert(value != nullptr && "value cannot be null"); + structType = value->getType().cast<LLVM::LLVMType>(); +} + +Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { + Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos); + return builder.create<LLVM::ExtractValueOp>(loc, type, value, + builder.getI64ArrayAttr(pos)); +} + +void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, + Value ptr) { + value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, + builder.getI64ArrayAttr(pos)); +} +/*============================================================================*/ +/* MemRefDescriptor implementation */ +/*============================================================================*/ + +/// Construct a helper for the given descriptor value. +MemRefDescriptor::MemRefDescriptor(Value descriptor) + : StructBuilder(descriptor) { + assert(value != nullptr && "value cannot be null"); + indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( + kOffsetPosInMemRefDescriptor); +} + +/// Builds IR creating an `undef` value of the descriptor type. +MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, + Type descriptorType) { + + Value descriptor = + builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); + return MemRefDescriptor(descriptor); +} + +/// Builds IR creating a MemRef descriptor that represents `type` and +/// populates it with static shape and stride information extracted from the +/// type. +MemRefDescriptor +MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + MemRefType type, Value memory) { + assert(type.hasStaticShape() && "unexpected dynamic shape"); + assert(type.getAffineMaps().empty() && "unexpected layout map"); + + auto convertedType = typeConverter.convertType(type); + assert(convertedType && "unexpected failure in memref type conversion"); + + auto descr = MemRefDescriptor::undef(builder, loc, convertedType); + descr.setAllocatedPtr(builder, loc, memory); + descr.setAlignedPtr(builder, loc, memory); + descr.setConstantOffset(builder, loc, 0); + + // Fill in sizes and strides, in reverse order to simplify stride + // calculation. 
+ uint64_t runningStride = 1; + for (unsigned i = type.getRank(); i > 0; --i) { + unsigned dim = i - 1; + descr.setConstantSize(builder, loc, dim, type.getDimSize(dim)); + descr.setConstantStride(builder, loc, dim, runningStride); + runningStride *= type.getDimSize(dim); + } + return descr; +} + +/// Builds IR extracting the allocated pointer from the descriptor. +Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); +} + +/// Builds IR inserting the allocated pointer into the descriptor. +void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value ptr) { + setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); +} + +/// Builds IR extracting the aligned pointer from the descriptor. +Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); +} + +/// Builds IR inserting the aligned pointer into the descriptor. +void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + Value ptr) { + setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); +} + +// Creates a constant Op producing a value of `resultType` from an index-typed +// integer attribute. +static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create<LLVM::ConstantOp>( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + +/// Builds IR extracting the offset from the descriptor. +Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); +} + +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + Value offset) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, offset, + builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); +} + +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, + uint64_t offset) { + setOffset(builder, loc, + createIndexAttrConstant(builder, loc, indexType, offset)); +} + +/// Builds IR extracting the pos-th size from the descriptor. +Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); +} + +/// Builds IR inserting the pos-th size into the descriptor +void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, + Value size) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, size, + builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); +} + +/// Builds IR inserting the pos-th size into the descriptor +void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, + unsigned pos, uint64_t size) { + setSize(builder, loc, pos, + createIndexAttrConstant(builder, loc, indexType, size)); +} + +/// Builds IR extracting the pos-th size from the descriptor. 
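For intuition about the fields being read and written by the helpers above and below, the descriptor can be pictured at runtime as the struct sketched here (rank 2, f32 elements, illustrative field names; plain C++, not generated code). fillStatic mirrors the reverse-order size/stride loop of fromStaticShape, assuming a static identity-layout shape.

#include <cstdint>

struct MemRefDescriptor2D {
  float *allocatedPtr; // pointer returned by the allocation, later freed
  float *alignedPtr;   // pointer actually used for element accesses
  int64_t offset;      // distance, in elements, from alignedPtr to element (0, 0)
  int64_t sizes[2];    // extent of each dimension, in elements
  int64_t strides[2];  // elements to skip to advance by one along a dimension
};

void fillStatic(MemRefDescriptor2D &d, const int64_t shape[2]) {
  d.offset = 0;
  int64_t runningStride = 1;
  for (int dim = 1; dim >= 0; --dim) { // reverse order simplifies the product
    d.sizes[dim] = shape[dim];
    d.strides[dim] = runningStride;
    runningStride *= shape[dim];
  }
}
// For shape {4, 8} this produces sizes {4, 8} and strides {8, 1}.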
+Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); +} + +/// Builds IR inserting the pos-th stride into the descriptor +void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, + Value stride) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, stride, + builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); +} + +/// Builds IR inserting the pos-th stride into the descriptor +void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, + unsigned pos, uint64_t stride) { + setStride(builder, loc, pos, + createIndexAttrConstant(builder, loc, indexType, stride)); +} + +LLVM::LLVMType MemRefDescriptor::getElementType() { + return value->getType().cast<LLVM::LLVMType>().getStructElementType( + kAlignedPtrPosInMemRefDescriptor); +} + +/*============================================================================*/ +/* UnrankedMemRefDescriptor implementation */ +/*============================================================================*/ + +/// Construct a helper for the given descriptor value. +UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) + : StructBuilder(descriptor) {} + +/// Builds IR creating an `undef` value of the descriptor type. +UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, + Location loc, + Type descriptorType) { + Value descriptor = + builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); + return UnrankedMemRefDescriptor(descriptor); +} +Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); +} +void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, + Value v) { + setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); +} +Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, + Location loc) { + return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); +} +void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, + Location loc, Value v) { + setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); +} +namespace { +// Base class for Standard to LLVM IR op conversions. Matches the Op type +// provided as template argument. Carries a reference to the LLVM dialect in +// case it is necessary for rewriters. +template <typename SourceOp> +class LLVMLegalizationPattern : public LLVMOpLowering { +public: + // Construct a conversion pattern. + explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, + LLVMTypeConverter &lowering_) + : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), + lowering_), + dialect(dialect_) {} + + // Get the LLVM IR dialect. + LLVM::LLVMDialect &getDialect() const { return dialect; } + // Get the LLVM context. + llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); } + // Get the LLVM module in which the types are constructed. + llvm::Module &getModule() const { return dialect.getLLVMModule(); } + + // Get the MLIR type wrapping the LLVM integer type whose bit width is defined + // by the pointer size used in the LLVM module. 
+ LLVM::LLVMType getIndexType() const { + return LLVM::LLVMType::getIntNTy( + &dialect, getModule().getDataLayout().getPointerSizeInBits()); + } + + LLVM::LLVMType getVoidType() const { + return LLVM::LLVMType::getVoidTy(&dialect); + } + + // Get the MLIR type wrapping the LLVM i8* type. + LLVM::LLVMType getVoidPtrType() const { + return LLVM::LLVMType::getInt8PtrTy(&dialect); + } + + // Create an LLVM IR pseudo-operation defining the given index constant. + Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, + uint64_t value) const { + return createIndexAttrConstant(builder, loc, getIndexType(), value); + } + +protected: + LLVM::LLVMDialect &dialect; +}; + +struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> { + using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = cast<FuncOp>(op); + FunctionType type = funcOp.getType(); + + // Store the positions of memref-typed arguments so that we can emit loads + // from them to follow the calling convention. + SmallVector<unsigned, 4> promotedArgIndices; + promotedArgIndices.reserve(type.getNumInputs()); + for (auto en : llvm::enumerate(type.getInputs())) { + if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>()) + promotedArgIndices.push_back(en.index()); + } + + // Convert the original function arguments. Struct arguments are promoted to + // pointer to struct arguments to allow calling external functions with + // various ABIs (e.g. compiled from C/C++ on platform X). + auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs"); + TypeConverter::SignatureConversion result(funcOp.getNumArguments()); + auto llvmType = lowering.convertFunctionSignature( + funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); + + // Only retain those attributes that are not constructed by build. + SmallVector<NamedAttribute, 4> attributes; + for (const auto &attr : funcOp.getAttrs()) { + if (attr.first.is(SymbolTable::getSymbolAttrName()) || + attr.first.is(impl::getTypeAttrName()) || + attr.first.is("std.varargs")) + continue; + attributes.push_back(attr); + } + + // Create an LLVM function, use external linkage by default until MLIR + // functions have linkage. + auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( + op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, + attributes); + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + + // Tell the rewriter to convert the region signature. + rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + + // Insert loads from memref descriptor pointers in function bodies. + if (!newFuncOp.getBody().empty()) { + Block *firstBlock = &newFuncOp.getBody().front(); + rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); + for (unsigned idx : promotedArgIndices) { + BlockArgument arg = firstBlock->getArgument(idx); + Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg); + rewriter.replaceUsesOfBlockArgument(arg, loaded); + } + } + + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +//////////////// Support for Lowering operations on n-D vectors //////////////// +namespace { +// Helper struct to "unroll" operations on n-D vectors in terms of operations on +// 1-D LLVM vectors. +struct NDVectorTypeInfo { + // LLVM array struct which encodes n-D vectors. 
+  LLVM::LLVMType llvmArrayTy;
+  // LLVM vector type which encodes the inner 1-D vector type.
+  LLVM::LLVMType llvmVectorTy;
+  // Multiplicity of llvmArrayTy to llvmVectorTy.
+  SmallVector<int64_t, 4> arraySizes;
+};
+} // namespace
+
+// For >1-D vector types, extracts the necessary information to iterate over
+// all 1-D subvectors in the underlying LLVM representation of the n-D vector.
+// Iterates on the llvm array type until we hit a non-array type (which is
+// asserted to be an llvm vector type).
+static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
+                                                LLVMTypeConverter &converter) {
+  assert(vectorType.getRank() > 1 && "expected >1D vector type");
+  NDVectorTypeInfo info;
+  info.llvmArrayTy =
+      converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
+  if (!info.llvmArrayTy)
+    return info;
+  info.arraySizes.reserve(vectorType.getRank() - 1);
+  auto llvmTy = info.llvmArrayTy;
+  while (llvmTy.isArrayTy()) {
+    info.arraySizes.push_back(llvmTy.getArrayNumElements());
+    llvmTy = llvmTy.getArrayElementType();
+  }
+  if (!llvmTy.isVectorTy())
+    return info;
+  info.llvmVectorTy = llvmTy;
+  return info;
+}
+
+// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when linearIndex is out of the range [0, P] where
+// P is the product of all the basis coordinates.
+//
+// Prerequisites:
+// Basis is an array of nonnegative integers (signed type inherited from
+// vector shape type).
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+                                              unsigned linearIndex) {
+  SmallVector<int64_t, 4> res;
+  res.reserve(basis.size());
+  for (unsigned basisElement : llvm::reverse(basis)) {
+    res.push_back(linearIndex % basisElement);
+    linearIndex = linearIndex / basisElement;
+  }
+  if (linearIndex > 0)
+    return {};
+  std::reverse(res.begin(), res.end());
+  return res;
+}
+
+// Iterate over the linear index, convert to coords space and insert a
+// splatted 1-D vector in each position.
+template <typename Lambda>
+void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
+                     Lambda fun) {
+  unsigned ub = 1;
+  for (auto s : info.arraySizes)
+    ub *= s;
+  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+    auto coords = getCoordinates(info.arraySizes, linearIndex);
+    // Linear index is out of bounds, we are done.
+    if (coords.empty())
+      break;
+    assert(coords.size() == info.arraySizes.size());
+    auto position = builder.getI64ArrayAttr(coords);
+    fun(position);
+  }
+}
+////////////// End Support for Lowering operations on n-D vectors //////////////
+
+// Basic lowering implementation for one-to-one rewriting from Standard Ops to
+// LLVM Dialect Ops.
+template <typename SourceOp, typename TargetOp>
+struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+  using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
+
+  // Convert the type of the result to an LLVM type, pass operands as is,
+  // preserve attributes.
+ PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + unsigned numResults = op->getNumResults(); + + Type packedType; + if (numResults != 0) { + packedType = this->lowering.packFunctionResults( + llvm::to_vector<4>(op->getResultTypes())); + if (!packedType) + return this->matchFailure(); + } + + auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands, + op->getAttrs()); + + // If the operation produced 0 or 1 result, return them immediately. + if (numResults == 0) + return rewriter.eraseOp(op), this->matchSuccess(); + if (numResults == 1) + return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), + this->matchSuccess(); + + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + SmallVector<Value, 4> results; + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = this->lowering.convertType(op->getResult(i)->getType()); + results.push_back(rewriter.create<LLVM::ExtractValueOp>( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } + rewriter.replaceOp(op, results); + return this->matchSuccess(); + } +}; + +template <typename SourceOp, unsigned OpCount> struct OpCountValidator { + static_assert( + std::is_base_of< + typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>, + SourceOp>::value, + "wrong operand count"); +}; + +template <typename SourceOp> struct OpCountValidator<SourceOp, 1> { + static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value, + "expected a single operand"); +}; + +template <typename SourceOp, unsigned OpCount> void ValidateOpCount() { + OpCountValidator<SourceOp, OpCount>(); +} + +// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect +// Ops for N-ary ops with one result. This supports higher-dimensional vector +// types. +template <typename SourceOp, typename TargetOp, unsigned OpCount> +struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { + using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; + using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>; + + // Convert the type of the result to an LLVM type, pass operands as is, + // preserve attributes. + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + ValidateOpCount<SourceOp, OpCount>(); + static_assert( + std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, + "expected single result op"); + static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, + SourceOp>::value, + "expected same operands and result type"); + + // Cannot convert ops if their operands are not of LLVM type. 
+ for (Value operand : operands) { + if (!operand || !operand->getType().isa<LLVM::LLVMType>()) + return this->matchFailure(); + } + + auto loc = op->getLoc(); + auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>(); + + if (!llvmArrayTy.isArrayTy()) { + auto newOp = rewriter.create<TargetOp>( + op->getLoc(), operands[0]->getType(), operands, op->getAttrs()); + rewriter.replaceOp(op, newOp.getResult()); + return this->matchSuccess(); + } + + auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>(); + if (!vectorType) + return this->matchFailure(); + auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering); + auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; + if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) + return this->matchFailure(); + + Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); + nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { + // For this unrolled `position` corresponding to the `linearIndex`^th + // element, extract operand vectors + SmallVector<Value, OpCount> extractedOperands; + for (unsigned i = 0; i < OpCount; ++i) { + extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( + loc, llvmVectorTy, operands[i], position)); + } + Value newVal = rewriter.create<TargetOp>( + loc, llvmVectorTy, extractedOperands, op->getAttrs()); + desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, + newVal, position); + }); + rewriter.replaceOp(op, desc); + return this->matchSuccess(); + } +}; + +template <typename SourceOp, typename TargetOp> +using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>; +template <typename SourceOp, typename TargetOp> +using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>; + +// Specific lowerings. +// FIXME: this should be tablegen'ed. 
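Before the individual op mappings, a standalone sketch (plain C++, illustrative names, not the MLIR API) of the delinearization that getCoordinates and nDVectorIterate above rely on: every linear index smaller than the product of the array sizes maps to exactly one position attribute.

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> coordinates(const std::vector<int64_t> &basis,
                                 unsigned linearIndex) {
  std::vector<int64_t> res;
  for (auto it = basis.rbegin(); it != basis.rend(); ++it) {
    res.push_back(linearIndex % *it);
    linearIndex /= *it;
  }
  if (linearIndex > 0) // out of range: mirror the empty-result convention above
    return {};
  std::reverse(res.begin(), res.end());
  return res;
}
// Usage: with basis {4, 8} (a vector<4x8x16xf32> holds 4 * 8 inner 1-D
// vectors), linearIndex 11 yields {1, 3}, i.e. the position attribute [1, 3].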
+struct AbsFOpLowering : public UnaryOpLLVMOpLowering<AbsFOp, LLVM::FAbsOp> { + using Super::Super; +}; +struct CeilFOpLowering : public UnaryOpLLVMOpLowering<CeilFOp, LLVM::FCeilOp> { + using Super::Super; +}; +struct CosOpLowering : public UnaryOpLLVMOpLowering<CosOp, LLVM::CosOp> { + using Super::Super; +}; +struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> { + using Super::Super; +}; +struct LogOpLowering : public UnaryOpLLVMOpLowering<LogOp, LLVM::LogOp> { + using Super::Super; +}; +struct Log10OpLowering : public UnaryOpLLVMOpLowering<Log10Op, LLVM::Log10Op> { + using Super::Super; +}; +struct Log2OpLowering : public UnaryOpLLVMOpLowering<Log2Op, LLVM::Log2Op> { + using Super::Super; +}; +struct NegFOpLowering : public UnaryOpLLVMOpLowering<NegFOp, LLVM::FNegOp> { + using Super::Super; +}; +struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> { + using Super::Super; +}; +struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> { + using Super::Super; +}; +struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> { + using Super::Super; +}; +struct SignedDivIOpLowering + : public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> { + using Super::Super; +}; +struct UnsignedDivIOpLowering + : public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> { + using Super::Super; +}; +struct SignedRemIOpLowering + : public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> { + using Super::Super; +}; +struct UnsignedRemIOpLowering + : public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> { + using Super::Super; +}; +struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> { + using Super::Super; +}; +struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> { + using Super::Super; +}; +struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> { + using Super::Super; +}; +struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> { + using Super::Super; +}; +struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> { + using Super::Super; +}; +struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> { + using Super::Super; +}; +struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> { + using Super::Super; +}; +struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> { + using Super::Super; +}; +struct CopySignOpLowering + : public BinaryOpLLVMOpLowering<CopySignOp, LLVM::CopySignOp> { + using Super::Super; +}; +struct SelectOpLowering + : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> { + using Super::Super; +}; +struct ConstLLVMOpLowering + : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> { + using Super::Super; +}; +struct ShiftLeftOpLowering + : public OneToOneLLVMOpLowering<ShiftLeftOp, LLVM::ShlOp> { + using Super::Super; +}; +struct SignedShiftRightOpLowering + : public OneToOneLLVMOpLowering<SignedShiftRightOp, LLVM::AShrOp> { + using Super::Super; +}; +struct UnsignedShiftRightOpLowering + : public OneToOneLLVMOpLowering<UnsignedShiftRightOp, LLVM::LShrOp> { + using Super::Super; +}; + +// Check if the MemRefType `type` is supported by the lowering. We currently +// only support memrefs with identity maps. 
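As a runtime-level preview of the allocation patterns that follow (a plain C++ sketch with illustrative names; the real lowering emits the equivalent LLVM dialect ops instead), the malloc-based path over-allocates by `alignment - 1` bytes and bumps the returned pointer to the next aligned address, keeping both pointers for the descriptor.

#include <cstdint>
#include <cstdlib>

void *allocAligned(size_t numElements, size_t elementSize, size_t alignment,
                   void **allocatedOut) {
  size_t bytes = numElements * elementSize + (alignment ? alignment - 1 : 0);
  char *allocated = static_cast<char *>(std::malloc(bytes));
  *allocatedOut = allocated; // "allocated" pointer, the one passed to free()
  if (!allocated || alignment == 0)
    return allocated;
  uintptr_t addr = reinterpret_cast<uintptr_t>(allocated);
  uintptr_t bump = (alignment - addr % alignment) % alignment;
  return allocated + bump;   // "aligned" pointer used for element accesses
}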
+static bool isSupportedMemRefType(MemRefType type) { + return type.getAffineMaps().empty() || + llvm::all_of(type.getAffineMaps(), + [](AffineMap map) { return map.isIdentity(); }); +} + +// An `alloc` is converted into a definition of a memref descriptor value and +// a call to `malloc` to allocate the underlying data buffer. The memref +// descriptor is of the LLVM structure type where: +// 1. the first element is a pointer to the allocated (typed) data buffer, +// 2. the second element is a pointer to the (typed) payload, aligned to the +// specified alignment, +// 3. the remaining elements serve to store all the sizes and strides of the +// memref using LLVM-converted `index` type. +// +// Alignment is obtained by allocating `alignment - 1` more bytes than requested +// and shifting the aligned pointer relative to the allocated memory. If +// alignment is unspecified, the two pointers are equal. +struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { + using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern; + + AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter, + bool useAlloca = false) + : LLVMLegalizationPattern<AllocOp>(dialect_, converter), + useAlloca(useAlloca) {} + + PatternMatchResult match(Operation *op) const override { + MemRefType type = cast<AllocOp>(op).getType(); + if (isSupportedMemRefType(type)) + return matchSuccess(); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + if (failed(successStrides)) + return matchFailure(); + + // Dynamic strides are ok if they can be deduced from dynamic sizes (which + // is guaranteed when succeeded(successStrides)). Dynamic offset however can + // never be alloc'ed. + if (offset == MemRefType::getDynamicStrideOrOffset()) + return matchFailure(); + + return matchSuccess(); + } + + void rewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto allocOp = cast<AllocOp>(op); + MemRefType type = allocOp.getType(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). + SmallVector<Value, 4> sizes; + sizes.reserve(type.getRank()); + unsigned i = 0; + for (int64_t s : type.getShape()) + sizes.push_back(s == -1 ? operands[i++] + : createIndexConstant(rewriter, loc, s)); + if (sizes.empty()) + sizes.push_back(createIndexConstant(rewriter, loc, 1)); + + // Compute the total number of memref elements. + Value cumulativeSize = sizes.front(); + for (unsigned i = 1, e = sizes.size(); i < e; ++i) + cumulativeSize = rewriter.create<LLVM::MulOp>( + loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]}); + + // Compute the size of an individual element. This emits the MLIR equivalent + // of the following sizeof(...) implementation in LLVM IR: + // %0 = getelementptr %elementType* null, %indexType 1 + // %1 = ptrtoint %elementType* %0 to %indexType + // which is a common pattern of getting the size of a type in bytes. 
+ auto elementType = type.getElementType(); + auto convertedPtrType = + lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo(); + auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); + auto one = createIndexConstant(rewriter, loc, 1); + auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, + ArrayRef<Value>{nullPtr, one}); + auto elementSize = + rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); + cumulativeSize = rewriter.create<LLVM::MulOp>( + loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize}); + + // Allocate the underlying buffer and store a pointer to it in the MemRef + // descriptor. + Value allocated = nullptr; + int alignment = 0; + Value alignmentValue = nullptr; + if (auto alignAttr = allocOp.alignment()) + alignment = alignAttr.getValue().getSExtValue(); + + if (useAlloca) { + allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(), + cumulativeSize, alignment); + } else { + // Insert the `malloc` declaration if it is not already present. + auto module = op->getParentOfType<ModuleOp>(); + auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc"); + if (!mallocFunc) { + OpBuilder moduleBuilder( + op->getParentOfType<ModuleOp>().getBodyRegion()); + mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>( + rewriter.getUnknownLoc(), "malloc", + LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(), + /*isVarArg=*/false)); + } + if (alignment != 0) { + alignmentValue = createIndexConstant(rewriter, loc, alignment); + cumulativeSize = rewriter.create<LLVM::SubOp>( + loc, + rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue), + one); + } + allocated = rewriter + .create<LLVM::CallOp>( + loc, getVoidPtrType(), + rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize) + .getResult(0); + } + + auto structElementType = lowering.convertType(elementType); + auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo( + type.getMemorySpace()); + Value bitcastAllocated = rewriter.create<LLVM::BitcastOp>( + loc, elementPtrType, ArrayRef<Value>(allocated)); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + assert(offset != MemRefType::getDynamicStrideOrOffset() && + "unexpected dynamic offset"); + + // 0-D memref corner case: they have size 1 ... + assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) || + (strides.size() == sizes.size())) && + "unexpected number of strides"); + + // Create the MemRef descriptor. + auto structType = lowering.convertType(type); + auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); + // Field 1: Allocated pointer, used for malloc/free. + memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); + + // Field 2: Actual aligned pointer to payload. 
+ Value bitcastAligned = bitcastAllocated; + if (!useAlloca && alignment != 0) { + assert(alignmentValue); + // offset = (align - (ptr % align))% align + Value intVal = rewriter.create<LLVM::PtrToIntOp>( + loc, this->getIndexType(), allocated); + Value ptrModAlign = + rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue); + Value subbed = + rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign); + Value offset = rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue); + Value aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(), + allocated, offset); + bitcastAligned = rewriter.create<LLVM::BitcastOp>( + loc, elementPtrType, ArrayRef<Value>(aligned)); + } + memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); + + // Field 3: Offset in aligned pointer. + memRefDescriptor.setOffset(rewriter, loc, + createIndexConstant(rewriter, loc, offset)); + + if (type.getRank() == 0) + // No size/stride descriptor in memref, return the descriptor value. + return rewriter.replaceOp(op, {memRefDescriptor}); + + // Fields 4 and 5: Sizes and strides of the strided MemRef. + // Store all sizes in the descriptor. Only dynamic sizes are passed in as + // operands to AllocOp. + Value runningStride = nullptr; + // Iterate strides in reverse order, compute runningStride and strideValues. + auto nStrides = strides.size(); + SmallVector<Value, 4> strideValues(nStrides, nullptr); + for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) { + int64_t index = nStrides - 1 - indexedStride.index(); + if (strides[index] == MemRefType::getDynamicStrideOrOffset()) + // Identity layout map is enforced in the match function, so we compute: + // `runningStride *= sizes[index]` + runningStride = + runningStride + ? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index]) + : createIndexConstant(rewriter, loc, 1); + else + runningStride = createIndexConstant(rewriter, loc, strides[index]); + strideValues[index] = runningStride; + } + // Fill size and stride descriptors in memref. + for (auto indexedSize : llvm::enumerate(sizes)) { + int64_t index = indexedSize.index(); + memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); + memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); + } + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + } + + bool useAlloca; +}; + +// A CallOp automatically promotes MemRefType to a sequence of alloca/store and +// passes the pointer to the MemRef across function boundaries. +template <typename CallOpType> +struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> { + using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern; + using Super = CallOpInterfaceLowering<CallOpType>; + using Base = LLVMLegalizationPattern<CallOpType>; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + OperandAdaptor<CallOpType> transformed(operands); + auto callOp = cast<CallOpType>(op); + + // Pack the result types into a struct. + Type packedResult; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + for (Type resType : resultTypes) { + assert(!resType.isa<UnrankedMemRefType>() && + "Returning unranked memref is not supported. 
Pass result as an" + "argument instead."); + (void)resType; + } + + if (numResults != 0) { + if (!(packedResult = this->lowering.packFunctionResults(resultTypes))) + return this->matchFailure(); + } + + auto promoted = this->lowering.promoteMemRefDescriptors( + op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter); + auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult, + promoted, op->getAttrs()); + + // If < 2 results, packing did not do anything and we can just return. + if (numResults < 2) { + rewriter.replaceOp(op, newOp.getResults()); + return this->matchSuccess(); + } + + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around + // a particular interaction between MemRefType and CallOp lowering. Find a + // way to avoid special casing. + SmallVector<Value, 4> results; + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = this->lowering.convertType(op->getResult(i)->getType()); + results.push_back(rewriter.create<LLVM::ExtractValueOp>( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } + rewriter.replaceOp(op, results); + + return this->matchSuccess(); + } +}; + +struct CallOpLowering : public CallOpInterfaceLowering<CallOp> { + using Super::Super; +}; + +struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> { + using Super::Super; +}; + +// A `dealloc` is converted into a call to `free` on the underlying data buffer. +// The memref descriptor being an SSA value, there is no need to clean it up +// in any way. +struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { + using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern; + + DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter, + bool useAlloca = false) + : LLVMLegalizationPattern<DeallocOp>(dialect_, converter), + useAlloca(useAlloca) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (useAlloca) + return rewriter.eraseOp(op), matchSuccess(); + + assert(operands.size() == 1 && "dealloc takes one operand"); + OperandAdaptor<DeallocOp> transformed(operands); + + // Insert the `free` declaration if it is not already present. + auto freeFunc = + op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free"); + if (!freeFunc) { + OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion()); + freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>( + rewriter.getUnknownLoc(), "free", + LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), + /*isVarArg=*/false)); + } + + MemRefDescriptor memref(transformed.memref()); + Value casted = rewriter.create<LLVM::BitcastOp>( + op->getLoc(), getVoidPtrType(), + memref.allocatedPtr(rewriter, op->getLoc())); + rewriter.replaceOpWithNewOp<LLVM::CallOp>( + op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted); + return matchSuccess(); + } + + bool useAlloca; +}; + +// A `tanh` is converted into a call to the `tanh` function. 
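A minimal sketch (plain C++; `tanhSymbolFor` is an illustrative helper, not part of the pattern) of the dispatch performed by the lowering below: pick the libm symbol whose signature matches the operand's element type, and refuse to match anything else.

#include <string>

enum class FloatKind { F32, F64, Other };

std::string tanhSymbolFor(FloatKind kind) {
  switch (kind) {
  case FloatKind::F32: return "tanhf"; // float overload in libm
  case FloatKind::F64: return "tanh";  // double overload in libm
  default:             return {};      // unsupported: the pattern fails to match
  }
}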
+struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> { + using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + + using LLVMFuncOpT = LLVM::LLVMFuncOp; + using LLVMTypeT = LLVM::LLVMType; + + OperandAdaptor<TanhOp> transformed(operands); + LLVMTypeT operandType = + transformed.operand()->getType().dyn_cast_or_null<LLVM::LLVMType>(); + + if (!operandType) + return matchFailure(); + + std::string functionName; + if (operandType.isFloatTy()) + functionName = "tanhf"; + else if (operandType.isDoubleTy()) + functionName = "tanh"; + else + return matchFailure(); + + // Get a reference to the tanh function, inserting it if necessary. + Operation *tanhFunc = + SymbolTable::lookupNearestSymbolFrom(op, functionName); + + LLVMFuncOpT tanhLLVMFunc; + if (tanhFunc) { + tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc); + } else { + PatternRewriter::InsertionGuard insertGuard(rewriter); + auto module = op->getParentOfType<ModuleOp>(); + rewriter.setInsertionPointToStart(module.getBody()); + tanhLLVMFunc = rewriter.create<LLVMFuncOpT>( + module.getLoc(), functionName, + LLVMTypeT::getFunctionTy(operandType, operandType, + /*isVarArg=*/false)); + } + + rewriter.replaceOpWithNewOp<LLVM::CallOp>( + op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc), + transformed.operand()); + return matchSuccess(); + } +}; + +struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { + using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern; + + PatternMatchResult match(Operation *op) const override { + auto memRefCastOp = cast<MemRefCastOp>(op); + Type srcType = memRefCastOp.getOperand()->getType(); + Type dstType = memRefCastOp.getType(); + + if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) { + MemRefType sourceType = + memRefCastOp.getOperand()->getType().cast<MemRefType>(); + MemRefType targetType = memRefCastOp.getType().cast<MemRefType>(); + return (isSupportedMemRefType(targetType) && + isSupportedMemRefType(sourceType)) + ? matchSuccess() + : matchFailure(); + } + + // At least one of the operands is unranked type + assert(srcType.isa<UnrankedMemRefType>() || + dstType.isa<UnrankedMemRefType>()); + + // Unranked to unranked cast is disallowed + return !(srcType.isa<UnrankedMemRefType>() && + dstType.isa<UnrankedMemRefType>()) + ? matchSuccess() + : matchFailure(); + } + + void rewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto memRefCastOp = cast<MemRefCastOp>(op); + OperandAdaptor<MemRefCastOp> transformed(operands); + + auto srcType = memRefCastOp.getOperand()->getType(); + auto dstType = memRefCastOp.getType(); + auto targetStructType = lowering.convertType(memRefCastOp.getType()); + auto loc = op->getLoc(); + + if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) { + // memref_cast is defined for source and destination memref types with the + // same element type, same mappings, same address space and same rank. + // Therefore a simple bitcast suffices. If not it is undefined behavior. 
+ rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType, + transformed.source()); + } else if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) { + // Casting ranked to unranked memref type + // Set the rank in the destination from the memref type + // Allocate space on the stack and copy the src memref descriptor + // Set the ptr in the destination to the stack space + auto srcMemRefType = srcType.cast<MemRefType>(); + int64_t rank = srcMemRefType.getRank(); + // ptr = AllocaOp sizeof(MemRefDescriptor) + auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(), + rewriter); + // voidptr = BitCastOp srcType* to void* + auto voidPtr = + rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr) + .getResult(); + // rank = ConstantOp srcRank + auto rankVal = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(rewriter.getIntegerType(64)), + rewriter.getI64IntegerAttr(rank)); + // undef = UndefOp + UnrankedMemRefDescriptor memRefDesc = + UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); + // d1 = InsertValueOp undef, rank, 0 + memRefDesc.setRank(rewriter, loc, rankVal); + // d2 = InsertValueOp d1, voidptr, 1 + memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); + rewriter.replaceOp(op, (Value)memRefDesc); + + } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) { + // Casting from unranked type to ranked. + // The operation is assumed to be doing a correct cast. If the destination + // type mismatches the unranked the type, it is undefined behavior. + UnrankedMemRefDescriptor memRefDesc(transformed.source()); + // ptr = ExtractValueOp src, 1 + auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); + // castPtr = BitCastOp i8* to structTy* + auto castPtr = + rewriter + .create<LLVM::BitcastOp>( + loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(), + ptr) + .getResult(); + // struct = LoadOp castPtr + auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr); + rewriter.replaceOp(op, loadOp.getResult()); + } else { + llvm_unreachable("Unsuppored unranked memref to unranked memref cast"); + } + } +}; + +// A `dim` is converted to a constant for static sizes and to an access to the +// size stored in the memref descriptor for dynamic sizes. +struct DimOpLowering : public LLVMLegalizationPattern<DimOp> { + using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto dimOp = cast<DimOp>(op); + OperandAdaptor<DimOp> transformed(operands); + MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>(); + + auto shape = type.getShape(); + int64_t index = dimOp.getIndex(); + // Extract dynamic size from the memref descriptor. + if (ShapedType::isDynamic(shape[index])) + rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor()) + .size(rewriter, op->getLoc(), index)}); + else + // Use constant for static size. + rewriter.replaceOp( + op, createIndexConstant(rewriter, op->getLoc(), shape[index])); + return matchSuccess(); + } +}; + +// Common base for load and store operations on MemRefs. Restricts the match +// to supported MemRef types. Provides functionality to emit code accessing a +// specific element of the underlying data buffer. 
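The size-based linearization provided by linearizeSubscripts in the base class below can be written as the following standalone helper (plain C++, illustrative names); it assumes at least one index, as the assertion in the real code does.

#include <cstdint>
#include <vector>

int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &sizes) {
  int64_t linear = indices.front(); // requires a non-empty index list
  for (size_t i = 1; i < indices.size(); ++i)
    linear = linear * sizes[i] + indices[i];
  return linear;
}
// Example: indices {2, 3} into a 4x8 buffer give 2 * 8 + 3 == 19.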
+template <typename Derived> +struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { + using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern; + using Base = LoadStoreOpLowering<Derived>; + + PatternMatchResult match(Operation *op) const override { + MemRefType type = cast<Derived>(op).getMemRefType(); + return isSupportedMemRefType(type) ? this->matchSuccess() + : this->matchFailure(); + } + + // Given subscript indices and array sizes in row-major order, + // i_n, i_{n-1}, ..., i_1 + // s_n, s_{n-1}, ..., s_1 + // obtain a value that corresponds to the linearized subscript + // \sum_k i_k * \prod_{j=1}^{k-1} s_j + // by accumulating the running linearized value. + // Note that `indices` and `allocSizes` are passed in the same order as they + // appear in load/store operations and memref type declarations. + Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, + ArrayRef<Value> indices, + ArrayRef<Value> allocSizes) const { + assert(indices.size() == allocSizes.size() && + "mismatching number of indices and allocation sizes"); + assert(!indices.empty() && "cannot linearize a 0-dimensional access"); + + Value linearized = indices.front(); + for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { + linearized = builder.create<LLVM::MulOp>( + loc, this->getIndexType(), + ArrayRef<Value>{linearized, allocSizes[i]}); + linearized = builder.create<LLVM::AddOp>( + loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]}); + } + return linearized; + } + + // This is a strided getElementPtr variant that linearizes subscripts as: + // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. + Value getStridedElementPtr(Location loc, Type elementTypePtr, + Value descriptor, ArrayRef<Value> indices, + ArrayRef<int64_t> strides, int64_t offset, + ConversionPatternRewriter &rewriter) const { + MemRefDescriptor memRefDescriptor(descriptor); + + Value base = memRefDescriptor.alignedPtr(rewriter, loc); + Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.offset(rewriter, loc) + : this->createIndexConstant(rewriter, loc, offset); + + for (int i = 0, e = indices.size(); i < e; ++i) { + Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.stride(rewriter, loc, i) + : this->createIndexConstant(rewriter, loc, strides[i]); + Value additionalOffset = + rewriter.create<LLVM::MulOp>(loc, indices[i], stride); + offsetValue = + rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset); + } + return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue); + } + + Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, + ArrayRef<Value> indices, ConversionPatternRewriter &rewriter, + llvm::Module &module) const { + LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides, + offset, rewriter); + } +}; + +// Load operation is lowered to obtaining a pointer to the indexed element +// and loading it. 
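At runtime, the lowered load amounts to the computation below (plain C++, illustrative names, strides and offset taken from the descriptor); the store lowering is identical except that it writes through the resulting pointer.

#include <cstdint>
#include <vector>

float loadElement(const float *alignedPtr, int64_t offset,
                  const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &strides) {
  int64_t linear = offset;
  for (size_t i = 0; i < indices.size(); ++i)
    linear += indices[i] * strides[i]; // base_offset + sum_i index_i * stride_i
  return alignedPtr[linear];           // llvm.load of the GEP computed above
}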
+struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> { + using Base::Base; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loadOp = cast<LoadOp>(op); + OperandAdaptor<LoadOp> transformed(operands); + auto type = loadOp.getMemRefType(); + + Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); + rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr); + return matchSuccess(); + } +}; + +// Store operation is lowered to obtaining a pointer to the indexed element, +// and storing the given value to it. +struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> { + using Base::Base; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto type = cast<StoreOp>(op).getMemRefType(); + OperandAdaptor<StoreOp> transformed(operands); + + Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); + rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(), + dataPtr); + return matchSuccess(); + } +}; + +// The prefetch operation is lowered in a way similar to the load operation +// except that the llvm.prefetch operation is used for replacement. +struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> { + using Base::Base; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto prefetchOp = cast<PrefetchOp>(op); + OperandAdaptor<PrefetchOp> transformed(operands); + auto type = prefetchOp.getMemRefType(); + + Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); + + // Replace with llvm.prefetch. + auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32)); + auto isWrite = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmI32Type, + rewriter.getI32IntegerAttr(prefetchOp.isWrite())); + auto localityHint = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmI32Type, + rewriter.getI32IntegerAttr(prefetchOp.localityHint().getZExtValue())); + auto isData = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmI32Type, + rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); + + rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite, + localityHint, isData); + return matchSuccess(); + } +}; + +// The lowering of index_cast becomes an integer conversion since index becomes +// an integer. If the bit width of the source and target integer types is the +// same, just erase the cast. If the target type is wider, sign-extend the +// value, otherwise truncate it. 
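The three cases described above correspond to ordinary integer conversions; a small sketch on plain integers (illustrative helpers, not the MLIR API):

#include <cstdint>

// Widening: sign-extend, so negative index values keep their meaning.
int64_t indexCastWiden(int32_t value) { return static_cast<int64_t>(value); }

// Narrowing: keep the low bits, matching llvm.trunc.
int32_t indexCastNarrow(int64_t value) { return static_cast<int32_t>(value); }

// Equal widths need no instruction at all; the cast is simply erased.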
+struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> { + using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + IndexCastOpOperandAdaptor transformed(operands); + auto indexCastOp = cast<IndexCastOp>(op); + + auto targetType = + this->lowering.convertType(indexCastOp.getResult()->getType()) + .cast<LLVM::LLVMType>(); + auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>(); + unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); + unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth(); + + if (targetBits == sourceBits) + rewriter.replaceOp(op, transformed.in()); + else if (targetBits < sourceBits) + rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, + transformed.in()); + else + rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, + transformed.in()); + return matchSuccess(); + } +}; + +// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two +// enums share the numerical values so just cast. +template <typename LLVMPredType, typename StdPredType> +static LLVMPredType convertCmpPredicate(StdPredType pred) { + return static_cast<LLVMPredType>(pred); +} + +struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> { + using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto cmpiOp = cast<CmpIOp>(op); + CmpIOpOperandAdaptor transformed(operands); + + rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( + op, lowering.convertType(cmpiOp.getResult()->getType()), + rewriter.getI64IntegerAttr(static_cast<int64_t>( + convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))), + transformed.lhs(), transformed.rhs()); + + return matchSuccess(); + } +}; + +struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> { + using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto cmpfOp = cast<CmpFOp>(op); + CmpFOpOperandAdaptor transformed(operands); + + rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( + op, lowering.convertType(cmpfOp.getResult()->getType()), + rewriter.getI64IntegerAttr(static_cast<int64_t>( + convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))), + transformed.lhs(), transformed.rhs()); + + return matchSuccess(); + } +}; + +struct SIToFPLowering + : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> { + using Super::Super; +}; + +struct FPExtLowering : public OneToOneLLVMOpLowering<FPExtOp, LLVM::FPExtOp> { + using Super::Super; +}; + +struct FPTruncLowering + : public OneToOneLLVMOpLowering<FPTruncOp, LLVM::FPTruncOp> { + using Super::Super; +}; + +struct SignExtendIOpLowering + : public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> { + using Super::Super; +}; + +struct TruncateIOpLowering + : public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> { + using Super::Super; +}; + +struct ZeroExtendIOpLowering + : public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> { + using Super::Super; +}; + +// Base class for LLVM IR lowering terminator operations with successors. 
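As a preview of the ReturnOp lowering defined below (plain C++ with illustrative names), a function with two results effectively returns a single struct value that the caller then unpacks field by field.

#include <cstdint>

struct PackedResults {
  int64_t r0;
  float r1;
};

PackedResults packedReturn(int64_t a, float b) {
  PackedResults packed{}; // llvm.mlir.undef
  packed.r0 = a;          // llvm.insertvalue ..., [0]
  packed.r1 = b;          // llvm.insertvalue ..., [1]
  return packed;          // llvm.return of the packed struct
}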
+template <typename SourceOp, typename TargetOp> +struct OneToOneLLVMTerminatorLowering + : public LLVMLegalizationPattern<SourceOp> { + using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; + using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> properOperands, + ArrayRef<Block *> destinations, + ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<ValueRange, 2> operandRanges(operands.begin(), operands.end()); + rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations, + operandRanges, op->getAttrs()); + return this->matchSuccess(); + } +}; + +// Special lowering pattern for `ReturnOps`. Unlike all other operations, +// `ReturnOp` interacts with the function signature and must have as many +// operands as the function has return values. Because in LLVM IR, functions +// can only return 0 or 1 value, we pack multiple values into a structure type. +// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if +// necessary before returning it +struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { + using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + unsigned numArguments = op->getNumOperands(); + + // If ReturnOp has 0 or 1 operand, create it and return immediately. + if (numArguments == 0) { + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, ArrayRef<Value>(), ArrayRef<Block *>(), op->getAttrs()); + return matchSuccess(); + } + if (numArguments == 1) { + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, ArrayRef<Value>(operands.front()), ArrayRef<Block *>(), + op->getAttrs()); + return matchSuccess(); + } + + // Otherwise, we need to pack the arguments into an LLVM struct type before + // returning. + auto packedType = + lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); + + Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); + for (unsigned i = 0; i < numArguments; ++i) { + packed = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), packedType, packed, operands[i], + rewriter.getI64ArrayAttr(i)); + } + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, llvm::makeArrayRef(packed), ArrayRef<Block *>(), op->getAttrs()); + return matchSuccess(); + } +}; + +// FIXME: this should be tablegen'ed as well. +struct BranchOpLowering + : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> { + using Super::Super; +}; +struct CondBranchOpLowering + : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> { + using Super::Super; +}; + +// The Splat operation is lowered to an insertelement + a shufflevector +// operation. Splat to only 1-d vector result types are lowered. +struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> { + using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto splatOp = cast<SplatOp>(op); + VectorType resultType = splatOp.getType().dyn_cast<VectorType>(); + if (!resultType || resultType.getRank() != 1) + return matchFailure(); + + // First insert it into an undef vector so we can shuffle it. 
+ auto vectorType = lowering.convertType(splatOp.getType()); + Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType); + auto zero = rewriter.create<LLVM::ConstantOp>( + op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)), + rewriter.getZeroAttr(rewriter.getIntegerType(32))); + + auto v = rewriter.create<LLVM::InsertElementOp>( + op->getLoc(), vectorType, undef, splatOp.getOperand(), zero); + + int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); + SmallVector<int32_t, 4> zeroValues(width, 0); + + // Shuffle the value across the desired number of elements. + ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); + rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs); + return matchSuccess(); + } +}; + +// The Splat operation is lowered to an insertelement + a shufflevector +// operation. Splat to only 2+-d vector result types are lowered by the +// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. +struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> { + using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto splatOp = cast<SplatOp>(op); + OperandAdaptor<SplatOp> adaptor(operands); + VectorType resultType = splatOp.getType().dyn_cast<VectorType>(); + if (!resultType || resultType.getRank() == 1) + return matchFailure(); + + // First insert it into an undef vector so we can shuffle it. + auto loc = op->getLoc(); + auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering); + auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; + auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; + if (!llvmArrayTy || !llvmVectorTy) + return matchFailure(); + + // Construct returned value. + Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); + + // Construct a 1-D vector with the splatted value that we insert in all the + // places within the returned descriptor. + Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy); + auto zero = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(rewriter.getIntegerType(32)), + rewriter.getZeroAttr(rewriter.getIntegerType(32))); + Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc, + adaptor.input(), zero); + + // Shuffle the value across the desired number of elements. + int64_t width = resultType.getDimSize(resultType.getRank() - 1); + SmallVector<int32_t, 4> zeroValues(width, 0); + ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); + v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs); + + // Iterate of linear index, convert to coords space and insert splatted 1-D + // vector in each position. + nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { + desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v, + position); + }); + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +/// Conversion pattern that transforms a subview op into: +/// 1. An `llvm.mlir.undef` operation to create a memref descriptor +/// 2. Updates to the descriptor to introduce the data ptr, offset, size +/// and stride. +/// The subview op is replaced by the descriptor. 
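The descriptor arithmetic described in the comment above, written out at runtime level for a rank-2 view (plain C++, illustrative names): the subview shares the buffer pointers, folds its offsets into the base offset, and rescales sizes and strides.

#include <cstdint>

struct ViewDescriptor2D {
  float *allocatedPtr;
  float *alignedPtr;
  int64_t offset;
  int64_t sizes[2];
  int64_t strides[2];
};

ViewDescriptor2D makeSubView(const ViewDescriptor2D &src,
                             const int64_t offsets[2], const int64_t sizes[2],
                             const int64_t strides[2]) {
  ViewDescriptor2D dst = src; // same allocated and aligned pointers
  for (int i = 0; i < 2; ++i) {
    dst.offset += offsets[i] * src.strides[i];    // fold offsets into the base
    dst.sizes[i] = sizes[i];                      // new extents
    dst.strides[i] = src.strides[i] * strides[i]; // compose the strides
  }
  return dst;
}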
+struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { + using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto viewOp = cast<SubViewOp>(op); + // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support + // having multiple variadic operands where each operand can have different + // number of entries, clean all of this up. + SmallVector<Value, 2> dynamicOffsets( + std::next(operands.begin()), + std::next(operands.begin(), 1 + viewOp.getNumOffsets())); + SmallVector<Value, 2> dynamicSizes( + std::next(operands.begin(), 1 + viewOp.getNumOffsets()), + std::next(operands.begin(), + 1 + viewOp.getNumOffsets() + viewOp.getNumSizes())); + SmallVector<Value, 2> dynamicStrides( + std::next(operands.begin(), + 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()), + operands.end()); + + auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>(); + auto sourceElementTy = + lowering.convertType(sourceMemRefType.getElementType()) + .dyn_cast_or_null<LLVM::LLVMType>(); + + auto viewMemRefType = viewOp.getType(); + auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) + .dyn_cast<LLVM::LLVMType>(); + auto targetDescTy = + lowering.convertType(viewMemRefType).dyn_cast_or_null<LLVM::LLVMType>(); + if (!sourceElementTy || !targetDescTy) + return matchFailure(); + + // Currently, only rank > 0 and full or no operands are supported. Fail to + // convert otherwise. + unsigned rank = sourceMemRefType.getRank(); + if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) || + (!dynamicSizes.empty() && rank != dynamicSizes.size()) || + (!dynamicStrides.empty() && rank != dynamicStrides.size())) + return matchFailure(); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); + if (failed(successStrides)) + return matchFailure(); + + // Create the descriptor. + MemRefDescriptor sourceMemRef(operands.front()); + auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); + + // Copy the buffer pointer from the old descriptor to the new one. + Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); + Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( + loc, targetElementTy.getPointerTo(), extracted); + targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); + + extracted = sourceMemRef.alignedPtr(rewriter, loc); + bitcastPtr = rewriter.create<LLVM::BitcastOp>( + loc, targetElementTy.getPointerTo(), extracted); + targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); + + // Extract strides needed to compute offset. + SmallVector<Value, 4> strideValues; + strideValues.reserve(viewMemRefType.getRank()); + for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) + strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); + + // Fill in missing dynamic sizes. + auto llvmIndexType = lowering.convertType(rewriter.getIndexType()); + if (dynamicSizes.empty()) { + dynamicSizes.reserve(viewMemRefType.getRank()); + auto shape = viewMemRefType.getShape(); + for (auto extent : shape) { + dynamicSizes.push_back(rewriter.create<LLVM::ConstantOp>( + loc, llvmIndexType, rewriter.getI64IntegerAttr(extent))); + } + } + + // Offset. 
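+ // The loop below accumulates the new offset as (illustrative formula):
+ //   offset = srcOffset + sum_i(dynamicOffsets[i] * srcStrides[i])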
+ Value baseOffset = sourceMemRef.offset(rewriter, loc); + for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { + Value min = dynamicOffsets[i]; + baseOffset = rewriter.create<LLVM::AddOp>( + loc, baseOffset, + rewriter.create<LLVM::MulOp>(loc, min, strideValues[i])); + } + targetMemRef.setOffset(rewriter, loc, baseOffset); + + // Update sizes and strides. + for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { + targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); + Value newStride; + if (dynamicStrides.empty()) + newStride = rewriter.create<LLVM::ConstantOp>( + loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); + else + newStride = rewriter.create<LLVM::MulOp>(loc, dynamicStrides[i], + strideValues[i]); + targetMemRef.setStride(rewriter, loc, i, newStride); + } + + rewriter.replaceOp(op, {targetMemRef}); + return matchSuccess(); + } +}; + +/// Conversion pattern that transforms a op into: +/// 1. An `llvm.mlir.undef` operation to create a memref descriptor +/// 2. Updates to the descriptor to introduce the data ptr, offset, size +/// and stride. +/// The view op is replaced by the descriptor. +struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { + using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern; + + // Build and return the value for the idx^th shape dimension, either by + // returning the constant shape dimension or counting the proper dynamic size. + Value getSize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<int64_t> shape, ArrayRef<Value> dynamicSizes, + unsigned idx) const { + assert(idx < shape.size()); + if (!ShapedType::isDynamic(shape[idx])) + return createIndexConstant(rewriter, loc, shape[idx]); + // Count the number of dynamic dims in range [0, idx] + unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) { + return ShapedType::isDynamic(v); + }); + return dynamicSizes[nDynamic]; + } + + // Build and return the idx^th stride, either by returning the constant stride + // or by computing the dynamic stride from the current `runningStride` and + // `nextSize`. The caller should keep a running stride and update it with the + // result returned by this function. + Value getStride(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<int64_t> strides, Value nextSize, + Value runningStride, unsigned idx) const { + assert(idx < strides.size()); + if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) + return createIndexConstant(rewriter, loc, strides[idx]); + if (nextSize) + return runningStride + ? 
rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) + : nextSize; + assert(!runningStride); + return createIndexConstant(rewriter, loc, 1); + } + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto viewOp = cast<ViewOp>(op); + ViewOpOperandAdaptor adaptor(operands); + + auto viewMemRefType = viewOp.getType(); + auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) + .dyn_cast<LLVM::LLVMType>(); + auto targetDescTy = + lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>(); + if (!targetDescTy) + return op->emitWarning("Target descriptor type not converted to LLVM"), + matchFailure(); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); + if (failed(successStrides)) + return op->emitWarning("cannot cast to non-strided shape"), + matchFailure(); + + // Create the descriptor. + MemRefDescriptor sourceMemRef(adaptor.source()); + auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); + + // Field 1: Copy the allocated pointer, used for malloc/free. + Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); + Value bitcastPtr = rewriter.create<LLVM::BitcastOp>( + loc, targetElementTy.getPointerTo(), extracted); + targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); + + // Field 2: Copy the actual aligned pointer to payload. + extracted = sourceMemRef.alignedPtr(rewriter, loc); + bitcastPtr = rewriter.create<LLVM::BitcastOp>( + loc, targetElementTy.getPointerTo(), extracted); + targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); + + // Field 3: Copy the offset in aligned pointer. + unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); + (void)numDynamicSizes; + bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset(); + auto sizeAndOffsetOperands = adaptor.operands(); + assert(llvm::size(sizeAndOffsetOperands) == + numDynamicSizes + (hasDynamicOffset ? 1 : 0)); + Value baseOffset = !hasDynamicOffset + ? createIndexConstant(rewriter, loc, offset) + // TODO(ntv): better adaptor. + : sizeAndOffsetOperands.front(); + targetMemRef.setOffset(rewriter, loc, baseOffset); + + // Early exit for 0-D corner case. + if (viewMemRefType.getRank() == 0) + return rewriter.replaceOp(op, {targetMemRef}), matchSuccess(); + + // Fields 4 and 5: Update sizes and strides. + if (strides.back() != 1) + return op->emitWarning("cannot cast to non-contiguous shape"), + matchFailure(); + Value stride = nullptr, nextSize = nullptr; + // Drop the dynamic stride from the operand list, if present. + ArrayRef<Value> sizeOperands(sizeAndOffsetOperands); + if (hasDynamicOffset) + sizeOperands = sizeOperands.drop_front(); + for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { + // Update size. + Value size = + getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); + targetMemRef.setSize(rewriter, loc, i, size); + // Update stride. + stride = getStride(rewriter, loc, strides, nextSize, stride, i); + targetMemRef.setStride(rewriter, loc, i, stride); + nextSize = size; + } + + rewriter.replaceOp(op, {targetMemRef}); + return matchSuccess(); + } +}; + +} // namespace + +static void ensureDistinctSuccessors(Block &bb) { + auto *terminator = bb.getTerminator(); + + // Find repeated successors with arguments. 
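+ // Illustrative example of the overall rewrite (not part of the original
+ // patch): a terminator such as
+ //   cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
+ // is rewritten so the repeated successor is reached through a fresh block:
+ //   cond_br %flag, ^bb1(%a : i32), ^dummy
+ // ^dummy:
+ //   br ^bb1(%b : i32)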
+ llvm::SmallDenseMap<Block *, SmallVector<int, 4>> successorPositions; + for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) { + Block *successor = terminator->getSuccessor(i); + // Blocks with no arguments are safe even if they appear multiple times + // because they don't need PHI nodes. + if (successor->getNumArguments() == 0) + continue; + successorPositions[successor].push_back(i); + } + + // If a successor appears for the second or more time in the terminator, + // create a new dummy block that unconditionally branches to the original + // destination, and retarget the terminator to branch to this new block. + // There is no need to pass arguments to the dummy block because it will be + // dominated by the original block and can therefore use any values defined in + // the original block. + for (const auto &successor : successorPositions) { + const auto &positions = successor.second; + // Start from the second occurrence of a block in the successor list. + for (auto position = std::next(positions.begin()), end = positions.end(); + position != end; ++position) { + auto *dummyBlock = new Block(); + bb.getParent()->push_back(dummyBlock); + auto builder = OpBuilder(dummyBlock); + SmallVector<Value, 8> operands( + terminator->getSuccessorOperands(*position)); + builder.create<BranchOp>(terminator->getLoc(), successor.first, operands); + terminator->setSuccessor(dummyBlock, *position); + for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e; + ++i) + terminator->eraseSuccessorOperand(*position, i); + } + } +} + +void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { + for (auto f : m.getOps<FuncOp>()) { + for (auto &bb : f.getBlocks()) { + ::ensureDistinctSuccessors(bb); + } + } +} + +/// Collect a set of patterns to convert from the Standard dialect to LLVM. 
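+///
+/// Illustrative usage sketch (mirrors LLVMLoweringPass below; `module`,
+/// `typeConverter`, and `ctx` are assumed to exist):
+///   OwningRewritePatternList patterns;
+///   populateStdToLLVMConversionPatterns(typeConverter, patterns);
+///   ConversionTarget target(ctx);
+///   target.addLegalDialect<LLVM::LLVMDialect>();
+///   if (failed(applyPartialConversion(module, target, patterns,
+///                                     &typeConverter)))
+///     /* handle failure */;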
+void mlir::populateStdToLLVMNonMemoryConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // FIXME: this should be tablegen'ed + // clang-format off + patterns.insert< + AbsFOpLowering, + AddFOpLowering, + AddIOpLowering, + AndOpLowering, + BranchOpLowering, + CallIndirectOpLowering, + CallOpLowering, + CeilFOpLowering, + CmpFOpLowering, + CmpIOpLowering, + CondBranchOpLowering, + CopySignOpLowering, + CosOpLowering, + ConstLLVMOpLowering, + DivFOpLowering, + ExpOpLowering, + LogOpLowering, + Log10OpLowering, + Log2OpLowering, + FPExtLowering, + FPTruncLowering, + IndexCastOpLowering, + MulFOpLowering, + MulIOpLowering, + NegFOpLowering, + OrOpLowering, + PrefetchOpLowering, + RemFOpLowering, + ReturnOpLowering, + SIToFPLowering, + SelectOpLowering, + ShiftLeftOpLowering, + SignExtendIOpLowering, + SignedDivIOpLowering, + SignedRemIOpLowering, + SignedShiftRightOpLowering, + SplatOpLowering, + SplatNdOpLowering, + SubFOpLowering, + SubIOpLowering, + TanhOpLowering, + TruncateIOpLowering, + UnsignedDivIOpLowering, + UnsignedRemIOpLowering, + UnsignedShiftRightOpLowering, + XOrOpLowering, + ZeroExtendIOpLowering>(*converter.getDialect(), converter); + // clang-format on +} + +void mlir::populateStdToLLVMMemoryConversionPatters( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // clang-format off + patterns.insert< + DimOpLowering, + FuncOpConversion, + LoadOpLowering, + MemRefCastOpLowering, + StoreOpLowering, + SubViewOpLowering, + ViewOpLowering>(*converter.getDialect(), converter); + patterns.insert< + AllocOpLowering, + DeallocOpLowering>( + *converter.getDialect(), converter, clUseAlloca.getValue()); + // clang-format on +} + +void mlir::populateStdToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); + populateStdToLLVMMemoryConversionPatters(converter, patterns); +} + +// Convert types using the stored LLVM IR module. +Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); } + +// Create an LLVM IR structure type if there is more than one result. +Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) { + assert(!types.empty() && "expected non-empty list of type"); + + if (types.size() == 1) + return convertType(types.front()); + + SmallVector<LLVM::LLVMType, 8> resultTypes; + resultTypes.reserve(types.size()); + for (auto t : types) { + auto converted = convertType(t).dyn_cast<LLVM::LLVMType>(); + if (!converted) + return {}; + resultTypes.push_back(converted); + } + + return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); +} + +Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, + OpBuilder &builder) { + auto *context = builder.getContext(); + auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); + auto indexType = IndexType::get(context); + // Alloca with proper alignment. We do not expect optimizations of this + // alloca op and so we omit allocating at the entry block. + auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo(); + Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty, + IntegerAttr::get(indexType, 1)); + Value allocated = + builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0); + // Store into the alloca'ed descriptor. 
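+ // Together with the alloca above, the emitted IR is roughly (illustrative,
+ // types abridged):
+ //   %c1  = llvm.mlir.constant(1 : index)
+ //   %ptr = llvm.alloca %c1 x <memref descriptor type>
+ //   llvm.store %desc, %ptr
+ // and the returned %ptr is what gets passed in place of the descriptor.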
+ builder.create<LLVM::StoreOp>(loc, operand, allocated); + return allocated; +} + +SmallVector<Value, 4> +LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, + ValueRange operands, + OpBuilder &builder) { + SmallVector<Value, 4> promotedOperands; + promotedOperands.reserve(operands.size()); + for (auto it : llvm::zip(opOperands, operands)) { + auto operand = std::get<0>(it); + auto llvmOperand = std::get<1>(it); + if (!operand->getType().isa<MemRefType>() && + !operand->getType().isa<UnrankedMemRefType>()) { + promotedOperands.push_back(operand); + continue; + } + promotedOperands.push_back( + promoteOneMemRefDescriptor(loc, llvmOperand, builder)); + } + return promotedOperands; +} + +/// Create an instance of LLVMTypeConverter in the given context. +static std::unique_ptr<LLVMTypeConverter> +makeStandardToLLVMTypeConverter(MLIRContext *context) { + return std::make_unique<LLVMTypeConverter>(context); +} + +namespace { +/// A pass converting MLIR operations into the LLVM IR dialect. +struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> { + // By default, the patterns are those converting Standard operations to the + // LLVMIR dialect. + explicit LLVMLoweringPass( + bool useAlloca = false, + LLVMPatternListFiller patternListFiller = + populateStdToLLVMConversionPatterns, + LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) + : patternListFiller(patternListFiller), + typeConverterMaker(converterBuilder) {} + + // Run the dialect converter on the module. + void runOnModule() override { + if (!typeConverterMaker || !patternListFiller) + return signalPassFailure(); + + ModuleOp m = getModule(); + LLVM::ensureDistinctSuccessors(m); + std::unique_ptr<LLVMTypeConverter> typeConverter = + typeConverterMaker(&getContext()); + if (!typeConverter) + return signalPassFailure(); + + OwningRewritePatternList patterns; + populateLoopToStdConversionPatterns(patterns, m.getContext()); + patternListFiller(*typeConverter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect>(); + if (failed(applyPartialConversion(m, target, patterns, &*typeConverter))) + signalPassFailure(); + } + + // Callback for creating a list of patterns. It is called every time in + // runOnModule since applyPartialConversion consumes the list. + LLVMPatternListFiller patternListFiller; + + // Callback for creating an instance of type converter. The converter + // constructor needs an MLIRContext, which is not available until runOnModule. 
+ LLVMTypeConverterMaker typeConverterMaker; +}; +} // end namespace + +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::createLowerToLLVMPass(bool useAlloca) { + return std::make_unique<LLVMLoweringPass>(useAlloca); +} + +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, + LLVMTypeConverterMaker typeConverterMaker, + bool useAlloca) { + return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller, + typeConverterMaker); +} + +static PassRegistration<LLVMLoweringPass> + pass("convert-std-to-llvm", + "Convert scalar and vector operations from the " + "Standard to the LLVM dialect", + [] { + return std::make_unique<LLVMLoweringPass>( + clUseAlloca.getValue(), populateStdToLLVMConversionPatterns, + makeStandardToLLVMTypeConverter); + }); diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt new file mode 100644 index 00000000000..fcced23a95e --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -0,0 +1,26 @@ +set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td) +mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRStandardToSPIRVIncGen) + +add_llvm_library(MLIRStandardToSPIRVTransforms + ConvertStandardToSPIRV.cpp + ConvertStandardToSPIRVPass.cpp + LegalizeStandardForSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + ) + +add_dependencies(MLIRStandardToSPIRVTransforms + MLIRStandardToSPIRVIncGen) + +target_link_libraries(MLIRStandardToSPIRVTransforms + MLIRIR + MLIRPass + MLIRSPIRV + MLIRSupport + MLIRTransformUtils + MLIRSPIRV + MLIRStandardOps + ) diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp new file mode 100644 index 00000000000..a02dee4419a --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -0,0 +1,314 @@ +//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert Standard Ops to the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/LayoutUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineMap.h" +#include "llvm/ADT/SetVector.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation conversion +//===----------------------------------------------------------------------===// + +namespace { + +/// Convert constant operation with IndexType return to SPIR-V constant +/// operation. Since IndexType is not used within SPIR-V dialect, this needs +/// special handling to make sure the result type and the type of the value +/// attribute are consistent. +// TODO(ravishankarm) : This should be moved into DRR. 
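+//
+// For example (illustrative), assuming the type converter maps `index` to i32:
+//   %c = constant 42 : index
+// becomes
+//   %c = spv.constant 42 : i32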
+class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> { +public: + using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert compare operation to SPIR-V dialect. +class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> { +public: + using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert integer binary operations to SPIR-V operations. Cannot use +/// tablegen for this. If the integer operation is on variables of IndexType, +/// the type of the return value of the replacement operation differs from +/// that of the replaced operation. This is not handled in tablegen-based +/// pattern specification. +// TODO(ravishankarm) : This should be moved into DRR. +template <typename StdOp, typename SPIRVOp> +class IntegerOpConversion final : public SPIRVOpLowering<StdOp> { +public: + using SPIRVOpLowering<StdOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(StdOp operation, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto resultType = + this->typeConverter.convertType(operation.getResult()->getType()); + rewriter.template replaceOpWithNewOp<SPIRVOp>( + operation, resultType, operands, ArrayRef<NamedAttribute>()); + return this->matchSuccess(); + } +}; + +/// Convert load -> spv.LoadOp. The operands of the replaced operation are of +/// IndexType while that of the replacement operation are of type i32. This is +/// not supported in tablegen based pattern specification. +// TODO(ravishankarm) : This should be moved into DRR. +class LoadOpConversion final : public SPIRVOpLowering<LoadOp> { +public: + using SPIRVOpLowering<LoadOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert return -> spv.Return. +// TODO(ravishankarm) : This should be moved into DRR. +class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> { +public: + using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert select -> spv.Select +// TODO(ravishankarm) : This should be moved into DRR. +class SelectOpConversion final : public SPIRVOpLowering<SelectOp> { +public: + using SPIRVOpLowering<SelectOp>::SPIRVOpLowering; + PatternMatchResult + matchAndRewrite(SelectOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Convert store -> spv.StoreOp. The operands of the replaced operation are +/// of IndexType while that of the replacement operation are of type i32. This +/// is not supported in tablegen based pattern specification. +// TODO(ravishankarm) : This should be moved into DRR. 
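+//
+// For example (illustrative):
+//   store %v, %m[%i, %j] : memref<4x4xf32>
+// becomes, roughly, an spv.AccessChain that linearizes the indices using the
+// static strides of the memref (%i * 4 + %j here), followed by an spv.Store
+// through the resulting pointer.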
+class StoreOpConversion final : public SPIRVOpLowering<StoreOp> { +public: + using SPIRVOpLowering<StoreOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Utility functions for operation conversion +//===----------------------------------------------------------------------===// + +/// Performs the index computation to get to the element pointed to by +/// `indices` using the layout map of `baseType`. + +// TODO(ravishankarm) : This method assumes that the `origBaseType` is a +// MemRefType with AffineMap that has static strides. Handle dynamic strides +spirv::AccessChainOp getElementPtr(OpBuilder &builder, + SPIRVTypeConverter &typeConverter, + Location loc, MemRefType origBaseType, + Value basePtr, ArrayRef<Value> indices) { + // Get base and offset of the MemRefType and verify they are static. + int64_t offset; + SmallVector<int64_t, 4> strides; + if (failed(getStridesAndOffset(origBaseType, strides, offset)) || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { + return nullptr; + } + + auto indexType = typeConverter.getIndexType(builder.getContext()); + + Value ptrLoc = nullptr; + assert(indices.size() == strides.size()); + for (auto index : enumerate(indices)) { + Value strideVal = builder.create<spirv::ConstantOp>( + loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); + Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); + ptrLoc = + (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult() + : update); + } + SmallVector<Value, 2> linearizedIndices; + // Add a '0' at the start to index into the struct. + linearizedIndices.push_back(builder.create<spirv::ConstantOp>( + loc, indexType, IntegerAttr::get(indexType, 0))); + linearizedIndices.push_back(ptrLoc); + return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); +} + +//===----------------------------------------------------------------------===// +// ConstantOp with index type. +//===----------------------------------------------------------------------===// + +PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( + ConstantOp constIndexOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (!constIndexOp.getResult()->getType().isa<IndexType>()) { + return matchFailure(); + } + // The attribute has index type which is not directly supported in + // SPIR-V. Get the integer value and create a new IntegerAttr. + auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>(); + if (!constAttr) { + return matchFailure(); + } + + // Use the bitwidth set in the value attribute to decide the result type + // of the SPIR-V constant operation since SPIR-V does not support index + // types. 
+ auto constVal = constAttr.getValue(); + auto constValType = constAttr.getType().dyn_cast<IndexType>(); + if (!constValType) { + return matchFailure(); + } + auto spirvConstType = + typeConverter.convertType(constIndexOp.getResult()->getType()); + auto spirvConstVal = + rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); + rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType, + spirvConstVal); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + CmpIOpOperandAdaptor cmpIOpOperands(operands); + + switch (cmpIOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp<spirvOp>( \ + cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \ + cmpIOpOperands.rhs()); \ + return matchSuccess(); + + DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); + DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); + DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); + DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); + DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); + DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + +#undef DISPATCH + + default: + break; + } + return matchFailure(); +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + LoadOpOperandAdaptor loadOperands(operands); + auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), + loadOp.memref()->getType().cast<MemRefType>(), + loadOperands.memref(), loadOperands.indices()); + rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, + /*memory_access =*/nullptr, + /*alignment =*/nullptr); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + if (returnOp.getNumOperands()) { + return matchFailure(); + } + rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + SelectOpOperandAdaptor selectOperands(operands); + rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(), + selectOperands.true_value(), + selectOperands.false_value()); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + StoreOpOperandAdaptor 
storeOperands(operands); + auto storePtr = + getElementPtr(rewriter, typeConverter, storeOp.getLoc(), + storeOp.memref()->getType().cast<MemRefType>(), + storeOperands.memref(), storeOperands.indices()); + rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, + storeOperands.value(), + /*memory_access =*/nullptr, + /*alignment =*/nullptr); + return matchSuccess(); +} + +namespace { +/// Import the Standard Ops to SPIR-V Patterns. +#include "StandardToSPIRV.cpp.inc" +} // namespace + +namespace mlir { +void populateStandardToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns) { + // Add patterns that lower operations into SPIR-V dialect. + populateWithGenerated(context, &patterns); + patterns.insert<ConstantIndexOpConversion, CmpIOpConversion, + IntegerOpConversion<AddIOp, spirv::IAddOp>, + IntegerOpConversion<MulIOp, spirv::IMulOp>, + IntegerOpConversion<SignedDivIOp, spirv::SDivOp>, + IntegerOpConversion<SignedRemIOp, spirv::SModOp>, + IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion, + ReturnOpConversion, SelectOpConversion, StoreOpConversion>( + context, typeConverter); +} +} // namespace mlir diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp new file mode 100644 index 00000000000..52456b6e46d --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -0,0 +1,89 @@ +//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR standard ops into the SPIR-V +// ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +/// A simple pattern for rewriting function signature to convert arguments of +/// functions to be of valid SPIR-V types. +class FuncOpConversion final : public SPIRVOpLowering<FuncOp> { +public: + using SPIRVOpLowering<FuncOp>::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// A pass converting MLIR Standard operations into the SPIR-V dialect. 
+class ConvertStandardToSPIRVPass + : public ModulePass<ConvertStandardToSPIRVPass> { + void runOnModule() override; +}; +} // namespace + +PatternMatchResult +FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + auto fnType = funcOp.getType(); + if (fnType.getNumResults()) { + return matchFailure(); + } + + TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); + { + for (auto argType : enumerate(funcOp.getType().getInputs())) { + auto convertedType = typeConverter.convertType(argType.value()); + signatureConverter.addInputs(argType.index(), convertedType); + } + } + + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), llvm::None)); + rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); + }); + return matchSuccess(); +} + +void ConvertStandardToSPIRVPass::runOnModule() { + OwningRewritePatternList patterns; + auto context = &getContext(); + auto module = getModule(); + + SPIRVTypeConverter typeConverter; + populateStandardToSPIRVPatterns(context, typeConverter, patterns); + patterns.insert<FuncOpConversion>(context, typeConverter); + ConversionTarget target(*(module.getContext())); + target.addLegalDialect<spirv::SPIRVDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); + + if (failed(applyPartialConversion(module, target, patterns))) { + return signalPassFailure(); + } +} + +std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() { + return std::make_unique<ConvertStandardToSPIRVPass>(); +} + +static PassRegistration<ConvertStandardToSPIRVPass> + pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect"); diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp new file mode 100644 index 00000000000..a658356f76c --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -0,0 +1,181 @@ +//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass legalizes operations before the conversion to SPIR-V +// dialect to handle ops that cannot be lowered directly. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// Merges subview operation with load operation. +class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> { +public: + using OpRewritePattern<LoadOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges subview operation with store operation. 
+class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> { +public: + using OpRewritePattern<StoreOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Utility functions for op legalization. +//===----------------------------------------------------------------------===// + +/// Given the 'indices' of an load/store operation where the memref is a result +/// of a subview op, returns the indices w.r.t to the source memref of the +/// subview op. For example +/// +/// %0 = ... : memref<12x42xf32> +/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to +/// memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// +/// could be folded into +/// +/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +static LogicalResult +resolveSourceIndices(Location loc, PatternRewriter &rewriter, + SubViewOp subViewOp, ValueRange indices, + SmallVectorImpl<Value> &sourceIndices) { + // TODO: Aborting when the offsets are static. There might be a way to fold + // the subview op with load even if the offsets have been canonicalized + // away. + if (subViewOp.getNumOffsets() == 0) + return failure(); + + ValueRange opOffsets = subViewOp.offsets(); + SmallVector<Value, 2> opStrides; + if (subViewOp.getNumStrides()) { + // If the strides are dynamic, get the stride operands. + opStrides = llvm::to_vector<2>(subViewOp.strides()); + } else { + // When static, the stride operands can be retrieved by taking the strides + // of the result of the subview op, and dividing the strides of the base + // memref. + SmallVector<int64_t, 2> staticStrides; + if (failed(subViewOp.getStaticStrides(staticStrides))) { + return failure(); + } + opStrides.reserve(opOffsets.size()); + for (auto stride : staticStrides) { + auto constValAttr = rewriter.getIntegerAttr( + IndexType::get(rewriter.getContext()), stride); + opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr)); + } + } + assert(opOffsets.size() == opStrides.size()); + + // New indices for the load are the current indices * subview_stride + + // subview_offset. + assert(indices.size() == opStrides.size()); + sourceIndices.resize(indices.size()); + for (auto index : llvm::enumerate(indices)) { + auto offset = opOffsets[index.index()]; + auto stride = opStrides[index.index()]; + auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); + sourceIndices[index.index()] = + rewriter.create<AddIOp>(loc, offset, mul).getResult(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and LoadOp. 
+//===----------------------------------------------------------------------===// + +PatternMatchResult +LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector<Value, 4> sourceIndices; + if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, + loadOp.indices(), sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(), + sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Folding SubViewOp and StoreOp. +//===----------------------------------------------------------------------===// + +PatternMatchResult +StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const { + auto subViewOp = + dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp()); + if (!subViewOp) { + return matchFailure(); + } + SmallVector<Value, 4> sourceIndices; + if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, + storeOp.indices(), sourceIndices))) + return matchFailure(); + + rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(), + subViewOp.source(), sourceIndices); + return matchSuccess(); +} + +//===----------------------------------------------------------------------===// +// Hook for adding patterns. +//===----------------------------------------------------------------------===// + +void mlir::populateStdLegalizationPatternsForSPIRVLowering( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context); +} + +//===----------------------------------------------------------------------===// +// Pass for testing just the legalization patterns. +//===----------------------------------------------------------------------===// + +namespace { +struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> { + void runOnOperation() override; +}; +} // namespace + +void SPIRVLegalization::runOnOperation() { + OwningRewritePatternList patterns; + auto *context = &getContext(); + populateStdLegalizationPatternsForSPIRVLowering(context, patterns); + applyPatternsGreedily(getOperation()->getRegions(), patterns); +} + +std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() { + return std::make_unique<SPIRVLegalization>(); +} + +static PassRegistration<SPIRVLegalization> + pass("legalize-std-for-spirv", "Legalize standard ops for SPIR-V lowering"); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td new file mode 100644 index 00000000000..6f3a6a82476 --- /dev/null +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td @@ -0,0 +1,35 @@ +//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==// + +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines Patterns to lower standard ops to SPIR-V. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_TD +#define MLIR_CONVERSION_STANDARDTOSPIRV_TD + +include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/SPIRV/SPIRVOps.td" + +class BinaryOpPattern<Op src, Op tgt> : + Pat<(src SPV_ScalarOrVector:$l, SPV_ScalarOrVector:$r), + (tgt $l, $r)>; + +def : BinaryOpPattern<AddFOp, SPV_FAddOp>; +def : BinaryOpPattern<DivFOp, SPV_FDivOp>; +def : BinaryOpPattern<MulFOp, SPV_FMulOp>; +def : BinaryOpPattern<RemFOp, SPV_FRemOp>; +def : BinaryOpPattern<SubFOp, SPV_FSubOp>; + +// Constant Op +// TODO(ravishankarm): Handle lowering other constant types. +def : Pat<(ConstantOp:$result $valueAttr), + (SPV_ConstantOp $valueAttr), + [(SPV_ScalarOrVector $result)]>; + +#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt new file mode 100644 index 00000000000..2aaec68f6c4 --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRVectorToLLVM + ConvertVectorToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRVectorToLLVM ${LIBS}) +target_link_libraries(MLIRVectorToLLVM ${LIBS}) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp new file mode 100644 index 00000000000..b48930c4dda --- /dev/null +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -0,0 +1,766 @@ +//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; + +template <typename T> +static LLVM::LLVMType getPtrToElementType(T containerType, + LLVMTypeConverter &lowering) { + return lowering.convertType(containerType.getElementType()) + .template cast<LLVM::LLVMType>() + .getPointerTo(); +} + +// Helper to reduce vector type by one rank at front. +static VectorType reducedVectorTypeFront(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); +} + +// Helper to reduce vector type by *all* but one rank at back. 
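+// For example (illustrative): vector<2x3x4xf32> becomes vector<4xf32> here,
+// whereas reducedVectorTypeFront above yields vector<3x4xf32>.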
+static VectorType reducedVectorTypeBack(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().take_back(), tp.getElementType()); +} + +// Helper that picks the proper sequence for inserting. +static Value insertOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value val1, + Value val2, Type llvmType, int64_t rank, int64_t pos) { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, + constant); + } + return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, + rewriter.getI64ArrayAttr(pos)); +} + +// Helper that picks the proper sequence for extracting. +static Value extractOne(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering, Location loc, Value val, + Type llvmType, int64_t rank, int64_t pos) { + if (rank == 1) { + auto idxType = rewriter.getIndexType(); + auto constant = rewriter.create<LLVM::ConstantOp>( + loc, lowering.convertType(idxType), + rewriter.getIntegerAttr(idxType, pos)); + return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, + constant); + } + return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, + rewriter.getI64ArrayAttr(pos)); +} + +class VectorBroadcastOpConversion : public LLVMOpLowering { +public: + explicit VectorBroadcastOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto broadcastOp = cast<vector::BroadcastOp>(op); + VectorType dstVectorType = broadcastOp.getVectorType(); + if (lowering.convertType(dstVectorType) == nullptr) + return matchFailure(); + // Rewrite when the full vector type can be lowered (which + // implies all 'reduced' types can be lowered too). + auto adaptor = vector::BroadcastOpOperandAdaptor(operands); + VectorType srcVectorType = + broadcastOp.getSourceType().dyn_cast<VectorType>(); + rewriter.replaceOp( + op, expandRanks(adaptor.source(), // source value to be expanded + op->getLoc(), // location of original broadcast + srcVectorType, dstVectorType, rewriter)); + return matchSuccess(); + } + +private: + // Expands the given source value over all the ranks, as defined + // by the source and destination type (a null source type denotes + // expansion from a scalar value into a vector). + // + // TODO(ajcbik): consider replacing this one-pattern lowering + // with a two-pattern lowering using other vector + // ops once all insert/extract/shuffle operations + // are available with lowering implemention. + // + Value expandRanks(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, + ConversionPatternRewriter &rewriter) const { + assert((dstVectorType != nullptr) && "invalid result type in broadcast"); + // Determine rank of source and destination. + int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; + int64_t dstRank = dstVectorType.getRank(); + int64_t curDim = dstVectorType.getDimSize(0); + if (srcRank < dstRank) + // Duplicate this rank. 
+ return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, + curDim, rewriter); + // If all trailing dimensions are the same, the broadcast consists of + // simply passing through the source value and we are done. Otherwise, + // any non-matching dimension forces a stretch along this rank. + assert((srcVectorType != nullptr) && (srcRank > 0) && + (srcRank == dstRank) && "invalid rank in broadcast"); + for (int64_t r = 0; r < dstRank; r++) { + if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { + return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, + curDim, rewriter); + } + } + return value; + } + + // Picks the best way to duplicate a single rank. For the 1-D case, a + // single insert-elt/shuffle is the most efficient expansion. For higher + // dimensions, however, we need dim x insert-values on a new broadcast + // with one less leading dimension, which will be lowered "recursively" + // to matching LLVM IR. + // For example: + // v = broadcast s : f32 to vector<4x2xf32> + // becomes: + // x = broadcast s : f32 to vector<2xf32> + // v = [x,x,x,x] + // becomes: + // x = [s,s] + // v = [x,x,x,x] + Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { + Type llvmType = lowering.convertType(dstVectorType); + assert((llvmType != nullptr) && "unlowerable vector type"); + if (rank == 1) { + Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); + Value expand = + insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); + SmallVector<int32_t, 4> zeroValues(dim, 0); + return rewriter.create<LLVM::ShuffleVectorOp>( + loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); + } + Value expand = expandRanks(value, loc, srcVectorType, + reducedVectorTypeFront(dstVectorType), rewriter); + Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + for (int64_t d = 0; d < dim; ++d) { + result = + insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); + } + return result; + } + + // Picks the best way to stretch a single rank. For the 1-D case, a + // single insert-elt/shuffle is the most efficient expansion when at + // a stretch. Otherwise, every dimension needs to be expanded + // individually and individually inserted in the resulting vector. + // For example: + // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> + // becomes: + // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> + // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> + // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> + // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> + // v = [a,b,c,d] + // becomes: + // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> + // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> + // a = [x, y] + // etc. 
+ Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, + VectorType dstVectorType, int64_t rank, int64_t dim, + ConversionPatternRewriter &rewriter) const { + Type llvmType = lowering.convertType(dstVectorType); + assert((llvmType != nullptr) && "unlowerable vector type"); + Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType); + bool atStretch = dim != srcVectorType.getDimSize(0); + if (rank == 1) { + assert(atStretch); + Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); + Value one = + extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); + Value expand = + insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); + SmallVector<int32_t, 4> zeroValues(dim, 0); + return rewriter.create<LLVM::ShuffleVectorOp>( + loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); + } + VectorType redSrcType = reducedVectorTypeFront(srcVectorType); + VectorType redDstType = reducedVectorTypeFront(dstVectorType); + Type redLlvmType = lowering.convertType(redSrcType); + for (int64_t d = 0; d < dim; ++d) { + int64_t pos = atStretch ? 0 : d; + Value one = + extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); + Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); + result = + insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); + } + return result; + } +}; + +class VectorShuffleOpConversion : public LLVMOpLowering { +public: + explicit VectorShuffleOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::ShuffleOpOperandAdaptor(operands); + auto shuffleOp = cast<vector::ShuffleOp>(op); + auto v1Type = shuffleOp.getV1VectorType(); + auto v2Type = shuffleOp.getV2VectorType(); + auto vectorType = shuffleOp.getVectorType(); + Type llvmType = lowering.convertType(vectorType); + auto maskArrayAttr = shuffleOp.mask(); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + // Get rank and dimension sizes. + int64_t rank = vectorType.getRank(); + assert(v1Type.getRank() == rank); + assert(v2Type.getRank() == rank); + int64_t v1Dim = v1Type.getDimSize(0); + + // For rank 1, where both operands have *exactly* the same vector type, + // there is direct shuffle support in LLVM. Use it! + if (rank == 1 && v1Type == v2Type) { + Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( + loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); + rewriter.replaceOp(op, shuffle); + return matchSuccess(); + } + + // For all other cases, insert the individual values individually. 
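+ // Illustrative example (not part of the original patch): for
+ //   %r = vector.shuffle %a, %b [1, 0, 2] : vector<2xf32>, vector<1xf32>
+ // the loop below extracts %a[1], %a[0] and %b[0], and inserts them at
+ // positions 0, 1 and 2 of an undef result vector.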
+ Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); + int64_t insPos = 0; + for (auto en : llvm::enumerate(maskArrayAttr)) { + int64_t extPos = en.value().cast<IntegerAttr>().getInt(); + Value value = adaptor.v1(); + if (extPos >= v1Dim) { + extPos -= v1Dim; + value = adaptor.v2(); + } + Value extract = + extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); + insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, + rank, insPos++); + } + rewriter.replaceOp(op, insert); + return matchSuccess(); + } +}; + +class VectorExtractElementOpConversion : public LLVMOpLowering { +public: + explicit VectorExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto extractEltOp = cast<vector::ExtractElementOp>(op); + auto vectorType = extractEltOp.getVectorType(); + auto llvmType = lowering.convertType(vectorType.getElementType()); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( + op, llvmType, adaptor.vector(), adaptor.position()); + return matchSuccess(); + } +}; + +class VectorExtractOpConversion : public LLVMOpLowering { +public: + explicit VectorExtractOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::ExtractOpOperandAdaptor(operands); + auto extractOp = cast<vector::ExtractOp>(op); + auto vectorType = extractOp.getVectorType(); + auto resultType = extractOp.getResult()->getType(); + auto llvmResultType = lowering.convertType(resultType); + auto positionArrayAttr = extractOp.position(); + + // Bail if result type cannot be lowered. + if (!llvmResultType) + return matchFailure(); + + // One-shot extraction of vector from array (only requires extractvalue). + if (resultType.isa<VectorType>()) { + Value extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, llvmResultType, adaptor.vector(), positionArrayAttr); + rewriter.replaceOp(op, extracted); + return matchSuccess(); + } + + // Potential extraction of 1-D vector from array. 
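+ // Illustrative example (types abridged): for
+ //   %e = vector.extract %v[3, 2] : vector<4x8xf32>
+ // the position prefix [3] is resolved with llvm.extractvalue to the inner
+ // 1-D vector, and the trailing index 2 is then applied below with
+ // llvm.extractelement.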
+ auto *context = op->getContext(); + Value extracted = adaptor.vector(); + auto positionAttrs = positionArrayAttr.getValue(); + if (positionAttrs.size() > 1) { + auto oneDVectorType = reducedVectorTypeBack(vectorType); + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs); + } + + // Remaining extraction of element from 1-D LLVM vector + auto position = positionAttrs.back().cast<IntegerAttr>(); + auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); + extracted = + rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); + rewriter.replaceOp(op, extracted); + + return matchSuccess(); + } +}; + +class VectorInsertElementOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::InsertElementOpOperandAdaptor(operands); + auto insertEltOp = cast<vector::InsertElementOp>(op); + auto vectorType = insertEltOp.getDestVectorType(); + auto llvmType = lowering.convertType(vectorType); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( + op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); + return matchSuccess(); + } +}; + +class VectorInsertOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::InsertOpOperandAdaptor(operands); + auto insertOp = cast<vector::InsertOp>(op); + auto sourceType = insertOp.getSourceType(); + auto destVectorType = insertOp.getDestVectorType(); + auto llvmResultType = lowering.convertType(destVectorType); + auto positionArrayAttr = insertOp.position(); + + // Bail if result type cannot be lowered. + if (!llvmResultType) + return matchFailure(); + + // One-shot insertion of a vector into an array (only requires insertvalue). + if (sourceType.isa<VectorType>()) { + Value inserted = rewriter.create<LLVM::InsertValueOp>( + loc, llvmResultType, adaptor.dest(), adaptor.source(), + positionArrayAttr); + rewriter.replaceOp(op, inserted); + return matchSuccess(); + } + + // Potential extraction of 1-D vector from array. + auto *context = op->getContext(); + Value extracted = adaptor.dest(); + auto positionAttrs = positionArrayAttr.getValue(); + auto position = positionAttrs.back().cast<IntegerAttr>(); + auto oneDVectorType = destVectorType; + if (positionAttrs.size() > 1) { + oneDVectorType = reducedVectorTypeBack(destVectorType); + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs); + } + + // Insertion of an element into a 1-D LLVM vector. 
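+    // For illustration, a sketch of the scalar n-D case (indices chosen for
+    // the example only): vector.insert %s, %v[3, 2] : f32 into vector<4x8xf32>
+    // first extracts the 1-D vector at position [3] (done above), then inserts
+    // %s at a constant i64 index 2 via llvm.insertelement, and finally writes
+    // the updated 1-D vector back at position [3] via llvm.insertvalue (both
+    // done below).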
+ auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); + Value inserted = rewriter.create<LLVM::InsertElementOp>( + loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), + constant); + + // Potential insertion of resulting 1-D vector into array. + if (positionAttrs.size() > 1) { + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, + adaptor.dest(), inserted, + nMinusOnePositionAttrs); + } + + rewriter.replaceOp(op, inserted); + return matchSuccess(); + } +}; + +class VectorOuterProductOpConversion : public LLVMOpLowering { +public: + explicit VectorOuterProductOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::OuterProductOpOperandAdaptor(operands); + auto *ctx = op->getContext(); + auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); + auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); + auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); + auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); + auto llvmArrayOfVectType = lowering.convertType( + cast<vector::OuterProductOp>(op).getResult()->getType()); + Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); + Value a = adaptor.lhs(), b = adaptor.rhs(); + Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); + SmallVector<Value, 8> lhs, accs; + lhs.reserve(rankLHS); + accs.reserve(rankLHS); + for (unsigned d = 0, e = rankLHS; d < e; ++d) { + // shufflevector explicitly requires i32. + auto attr = rewriter.getI32IntegerAttr(d); + SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); + auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); + Value aD = nullptr, accD = nullptr; + // 1. Broadcast the element a[d] into vector aD. + aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); + // 2. If acc is present, extract 1-d vector acc[d] into accD. + if (acc) + accD = rewriter.create<LLVM::ExtractValueOp>( + loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); + // 3. Compute aD outer b (plus accD, if relevant). + Value aOuterbD = + accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) + .getResult() + : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); + // 4. Insert as value `d` in the descriptor. 
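+      // For illustration, a sketch of one iteration (d = 0, assuming a
+      // vector<2xf32> lhs and a vector<3xf32> rhs): %aD broadcasts a[0] via a
+      // shufflevector with mask [0, 0, 0], %aOuterbD is fmul(%aD, b) (or the
+      // fmuladd intrinsic with acc[0] when an accumulator is present), and
+      // the insertvalue below stores that row at position [0] of the
+      // [2 x <3 x float>] result.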
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, + desc, aOuterbD, + rewriter.getI64ArrayAttr(d)); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +class VectorTypeCastOpConversion : public LLVMOpLowering { +public: + explicit VectorTypeCastOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); + MemRefType sourceMemRefType = + castOp.getOperand()->getType().cast<MemRefType>(); + MemRefType targetMemRefType = + castOp.getResult()->getType().cast<MemRefType>(); + + // Only static shape casts supported atm. + if (!sourceMemRefType.hasStaticShape() || + !targetMemRefType.hasStaticShape()) + return matchFailure(); + + auto llvmSourceDescriptorTy = + operands[0]->getType().dyn_cast<LLVM::LLVMType>(); + if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) + return matchFailure(); + MemRefDescriptor sourceMemRef(operands[0]); + + auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) + .dyn_cast_or_null<LLVM::LLVMType>(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return matchFailure(); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = + getStridesAndOffset(sourceMemRefType, strides, offset); + bool isContiguous = (strides.back() == 1); + if (isContiguous) { + auto sizes = sourceMemRefType.getShape(); + for (int index = 0, e = strides.size() - 2; index < e; ++index) { + if (strides[index] != strides[index + 1] * sizes[index + 1]) { + isContiguous = false; + break; + } + } + } + // Only contiguous source tensors supported atm. + if (failed(successStrides) || !isContiguous) + return matchFailure(); + + auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + + // Create descriptor. + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); + // Set allocated ptr. + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); + allocated = + rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value ptr = sourceMemRef.alignedPtr(rewriter, loc); + ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); + desc.setAlignedPtr(rewriter, loc, ptr); + // Fill offset 0. + auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); + auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); + desc.setOffset(rewriter, loc, zero); + + // Fill size and stride descriptors in memref. 
+ for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { + int64_t index = indexedSize.index(); + auto sizeAttr = + rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); + auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); + desc.setSize(rewriter, loc, index, size); + auto strideAttr = + rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); + auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); + desc.setStride(rewriter, loc, index, stride); + } + + rewriter.replaceOp(op, {desc}); + return matchSuccess(); + } +}; + +class VectorPrintOpConversion : public LLVMOpLowering { +public: + explicit VectorPrintOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::PrintOp::getOperationName(), context, + typeConverter) {} + + // Proof-of-concept lowering implementation that relies on a small + // runtime support library, which only needs to provide a few + // printing methods (single value for all data types, opening/closing + // bracket, comma, newline). The lowering fully unrolls a vector + // in terms of these elementary printing operations. The advantage + // of this approach is that the library can remain unaware of all + // low-level implementation details of vectors while still supporting + // output of any shaped and dimensioned vector. Due to full unrolling, + // this approach is less suited for very large vectors though. + // + // TODO(ajcbik): rely solely on libc in future? something else? + // + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto printOp = cast<vector::PrintOp>(op); + auto adaptor = vector::PrintOpOperandAdaptor(operands); + Type printType = printOp.getPrintType(); + + if (lowering.convertType(printType) == nullptr) + return matchFailure(); + + // Make sure element type has runtime support (currently just Float/Double). + VectorType vectorType = printType.dyn_cast<VectorType>(); + Type eltType = vectorType ? vectorType.getElementType() : printType; + int64_t rank = vectorType ? vectorType.getRank() : 0; + Operation *printer; + if (eltType.isF32()) + printer = getPrintFloat(op); + else if (eltType.isF64()) + printer = getPrintDouble(op); + else + return matchFailure(); + + // Unroll vector into elementary print calls. + emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); + emitCall(rewriter, op->getLoc(), getPrintNewline(op)); + rewriter.eraseOp(op); + return matchSuccess(); + } + +private: + void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, + Value value, VectorType vectorType, Operation *printer, + int64_t rank) const { + Location loc = op->getLoc(); + if (rank == 0) { + emitCall(rewriter, loc, printer, value); + return; + } + + emitCall(rewriter, loc, getPrintOpen(op)); + Operation *printComma = getPrintComma(op); + int64_t dim = vectorType.getDimSize(0); + for (int64_t d = 0; d < dim; ++d) { + auto reducedType = + rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; + auto llvmType = lowering.convertType( + rank > 1 ? reducedType : vectorType.getElementType()); + Value nestedVal = + extractOne(rewriter, lowering, loc, value, llvmType, rank, d); + emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); + if (d != dim - 1) + emitCall(rewriter, loc, printComma); + } + emitCall(rewriter, loc, getPrintClose(op)); + } + + // Helper to emit a call. 
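+  // For illustration, a single emitted call looks roughly like
+  //   llvm.call @print_open() : () -> ()
+  // so printing a vector<2xf32> unrolls into print_open, two print_f32 calls
+  // separated by print_comma, print_close, and a final print_newline.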
+ static void emitCall(ConversionPatternRewriter &rewriter, Location loc, + Operation *ref, ValueRange params = ValueRange()) { + rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, + rewriter.getSymbolRefAttr(ref), params); + } + + // Helper for printer method declaration (first hit) and lookup. + static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, + StringRef name, ArrayRef<LLVM::LLVMType> params) { + auto module = op->getParentOfType<ModuleOp>(); + auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); + if (func) + return func; + OpBuilder moduleBuilder(module.getBodyRegion()); + return moduleBuilder.create<LLVM::LLVMFuncOp>( + op->getLoc(), name, + LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), + params, /*isVarArg=*/false)); + } + + // Helpers for method names. + Operation *getPrintFloat(Operation *op) const { + LLVM::LLVMDialect *dialect = lowering.getDialect(); + return getPrint(op, dialect, "print_f32", + LLVM::LLVMType::getFloatTy(dialect)); + } + Operation *getPrintDouble(Operation *op) const { + LLVM::LLVMDialect *dialect = lowering.getDialect(); + return getPrint(op, dialect, "print_f64", + LLVM::LLVMType::getDoubleTy(dialect)); + } + Operation *getPrintOpen(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_open", {}); + } + Operation *getPrintClose(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_close", {}); + } + Operation *getPrintComma(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_comma", {}); + } + Operation *getPrintNewline(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_newline", {}); + } +}; + +/// Populate the given list with patterns that convert from Vector to LLVM. +void mlir::populateVectorToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, + VectorExtractElementOpConversion, VectorExtractOpConversion, + VectorInsertElementOpConversion, VectorInsertOpConversion, + VectorOuterProductOpConversion, VectorTypeCastOpConversion, + VectorPrintOpConversion>(converter.getDialect()->getContext(), + converter); +} + +namespace { +struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { + void runOnModule() override; +}; +} // namespace + +void LowerVectorToLLVMPass::runOnModule() { + // Convert to the LLVM IR dialect using the converter defined above. 
+ OwningRewritePatternList patterns; + LLVMTypeConverter converter(&getContext()); + populateVectorToLLVMConversionPatterns(converter, patterns); + populateStdToLLVMConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect<LLVM::LLVMDialect>(); + target.addDynamicallyLegalOp<FuncOp>( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + if (failed( + applyPartialConversion(getModule(), target, patterns, &converter))) { + signalPassFailure(); + } +} + +OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { + return new LowerVectorToLLVMPass(); +} + +static PassRegistration<LowerVectorToLLVMPass> + pass("convert-vector-to-llvm", + "Lower the operations from the vector dialect into the LLVM dialect"); diff --git a/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt b/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt new file mode 100644 index 00000000000..e213bc9bcce --- /dev/null +++ b/mlir/lib/Conversion/VectorToLoops/CMakeLists.txt @@ -0,0 +1,15 @@ +add_llvm_library(MLIRVectorToLoops + ConvertVectorToLoops.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLoops +) +set(LIBS + MLIRLLVMIR + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRVectorToLoops ${LIBS}) +target_link_libraries(MLIRVectorToLoops ${LIBS}) diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp new file mode 100644 index 00000000000..3ed031b985a --- /dev/null +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -0,0 +1,358 @@ +//===- VectorToLoops.cpp - Conversion from Vector to mix of Loops and Std -===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-dependent lowering of vector transfer operations. +// +//===----------------------------------------------------------------------===// + +#include <type_traits> + +#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h" +#include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" + +using namespace mlir; +using vector::TransferReadOp; +using vector::TransferWriteOp; + +/// Analyzes the `transfer` to find an access dimension along the fastest remote +/// MemRef dimension. If such a dimension with coalescing properties is found, +/// `pivs` and `vectorView` are swapped so that the invocation of +/// LoopNestBuilder captures it in the innermost loop. +template <typename TransferOpTy> +static void coalesceCopy(TransferOpTy transfer, + SmallVectorImpl<edsc::ValueHandle *> *pivs, + edsc::VectorView *vectorView) { + // rank of the remote memory access, coalescing behavior occurs on the + // innermost memory dimension. + auto remoteRank = transfer.getMemRefType().getRank(); + // Iterate over the results expressions of the permutation map to determine + // the loop order for creating pointwise copies between remote and local + // memories. 
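+  // For example (a sketch): with permutation_map = (d0, d1, d2) -> (d2, d1)
+  // and a rank-3 remote memref, the result `d2` addresses the innermost memref
+  // dimension (remoteRank - 1), so the loop created for it is swapped into the
+  // last position below and becomes the innermost loop of the copy nest.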
+  int coalescedIdx = -1;
+  auto exprs = transfer.permutation_map().getResults();
+  for (auto en : llvm::enumerate(exprs)) {
+    auto dim = en.value().template dyn_cast<AffineDimExpr>();
+    if (!dim) {
+      continue;
+    }
+    auto memRefDim = dim.getPosition();
+    if (memRefDim == remoteRank - 1) {
+      // memRefDim has coalescing properties; it should be swapped into the
+      // last position.
+      assert(coalescedIdx == -1 && "Unexpected > 1 coalesced indices");
+      coalescedIdx = en.index();
+    }
+  }
+  if (coalescedIdx >= 0) {
+    std::swap(pivs->back(), (*pivs)[coalescedIdx]);
+    vectorView->swapRanges(pivs->size() - 1, coalescedIdx);
+  }
+}
+
+/// Emits remote memory accesses that are clipped to the boundaries of the
+/// MemRef.
+template <typename TransferOpTy>
+static SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer,
+                                              edsc::MemRefView &view,
+                                              ArrayRef<edsc::IndexHandle> ivs) {
+  using namespace mlir::edsc;
+  using namespace edsc::op;
+  using edsc::intrinsics::select;
+
+  IndexHandle zero(index_t(0)), one(index_t(1));
+  SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices());
+  SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
+      memRefAccess.size(), edsc::IndexHandle());
+
+  // Indices accessing the remote memory are clipped, and their expressions are
+  // returned in clippedScalarAccessExprs.
+  for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
+       ++memRefDim) {
+    // Linear search on a small number of entries.
+    int loopIndex = -1;
+    auto exprs = transfer.permutation_map().getResults();
+    for (auto en : llvm::enumerate(exprs)) {
+      auto expr = en.value();
+      auto dim = expr.template dyn_cast<AffineDimExpr>();
+      // Sanity check.
+      assert(
+          (dim || expr.template cast<AffineConstantExpr>().getValue() == 0) &&
+          "Expected dim or 0 in permutationMap");
+      if (dim && memRefDim == dim.getPosition()) {
+        loopIndex = en.index();
+        break;
+      }
+    }
+
+    // We cannot distinguish atm between unrolled dimensions that implement
+    // the "always full" tile abstraction and need clipping from the other
+    // ones. So we conservatively clip everything.
+    auto N = view.ub(memRefDim);
+    auto i = memRefAccess[memRefDim];
+    if (loopIndex < 0) {
+      auto N_minus_1 = N - one;
+      auto select_1 = select(i < N, i, N_minus_1);
+      clippedScalarAccessExprs[memRefDim] = select(i < zero, zero, select_1);
+    } else {
+      auto ii = ivs[loopIndex];
+      auto i_plus_ii = i + ii;
+      auto N_minus_1 = N - one;
+      auto select_1 = select(i_plus_ii < N, i_plus_ii, N_minus_1);
+      clippedScalarAccessExprs[memRefDim] =
+          select(i_plus_ii < zero, zero, select_1);
+    }
+  }
+
+  return clippedScalarAccessExprs;
+}
+
+namespace {
+
+using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
+
+/// Implements lowering of TransferReadOp and TransferWriteOp to a
+/// proper abstraction for the hardware.
+///
+/// For now, we only emit a simple loop nest that performs clipped pointwise
+/// copies from a remote to a locally allocated memory.
+///
+/// Consider the case:
+///
+/// ```mlir
+///    // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
+///    // vector<32x256xf32> and pad with %f0 to handle the boundary case:
+///    %f0 = constant 0.0f : f32
+///    loop.for %i0 = 0 to %0 {
+///      loop.for %i1 = 0 to %1 step %c256 {
+///        loop.for %i2 = 0 to %2 step %c32 {
+///          %v = vector.transfer_read %A[%i0, %i1, %i2], %f0
+///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
+///               memref<?x?x?xf32>, vector<32x256xf32>
+///    }}}
+/// ```
+///
+/// The rewriters construct a loop nest and indices that access MemRef A in a
+/// pattern resembling the following (while guaranteeing an always full-tile
+/// abstraction):
+///
+/// ```mlir
+///    loop.for %d2 = 0 to %c256 {
+///      loop.for %d1 = 0 to %c32 {
+///        %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
+///        %tmp[%d2, %d1] = %s
+///      }
+///    }
+/// ```
+///
+/// In the current state, only a clipping transfer is implemented by `clip`,
+/// which creates individual indexing expressions of the form:
+///
+/// ```mlir-dsc
+///    auto condMax = i + ii < N;
+///    auto max = select(condMax, i + ii, N - one)
+///    auto cond = i + ii < zero;
+///    select(cond, zero, max);
+/// ```
+///
+/// In the future, clipping should not be the only way and instead we should
+/// load vectors + mask them. Similarly on the write side, load/mask/store for
+/// implementing RMW behavior.
+///
+/// Lowers TransferOp into a combination of:
+///   1. local memory allocation;
+///   2. perfect loop nest over:
+///      a. scalar load/stores from local buffers (viewed as a scalar memref);
+///      b. scalar store/load to original memref (with clipping);
+///   3. vector_load/store;
+///   4. local memory deallocation.
+/// Minor variations occur depending on whether a TransferReadOp or
+/// a TransferWriteOp is rewritten.
+template <typename TransferOpTy>
+struct VectorTransferRewriter : public RewritePattern {
+  explicit VectorTransferRewriter(MLIRContext *context)
+      : RewritePattern(TransferOpTy::getOperationName(), 1, context) {}
+
+  /// Used for staging the transfer in a local scalar buffer.
+  MemRefType tmpMemRefType(TransferOpTy transfer) const {
+    auto vectorType = transfer.getVectorType();
+    return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
+                           {}, 0);
+  }
+
+  /// Performs the rewrite.
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
+};
+
+/// Lowers TransferReadOp into a combination of:
+///   1. local memory allocation;
+///   2. perfect loop nest over:
+///      a. scalar load from local buffers (viewed as a scalar memref);
+///      b. scalar store to original memref (with clipping);
+///   3. vector_load from local buffer (viewed as a memref<1 x vector>);
+///   4. local memory deallocation.
+///
+/// Lowers the data transfer part of a TransferReadOp while ensuring no
+/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
+/// clipping. This means that a given value in memory can be read multiple
+/// times and concurrently.
+///
+/// Important notes about clipping and "full-tiles only" abstraction:
+/// =================================================================
+/// When using clipping for dealing with boundary conditions, the same edge
+/// value will appear multiple times (a.k.a. edge padding). This is fine if the
+/// subsequent vector operations are all data-parallel but **is generally
+/// incorrect** in the presence of reductions or extract operations.
+///
+/// More generally, clipping is a scalar abstraction that is expected to work
+/// fine as a baseline for CPUs and GPUs but not for vector_load and DMAs.
+/// To deal with real vector_load and DMAs, a "padded allocation + view"
+/// abstraction with the ability to read out-of-memref-bounds (but still within
+/// the allocated region) is necessary.
+///
+/// Whether using scalar loops or vector_load/DMAs to perform the transfer,
+/// junk values will be materialized in the vectors and generally need to be
+/// filtered out and replaced by the "neutral element". This neutral element is
+/// op-dependent so, in the future, we expect to create a vector filter and
+/// apply it to a splatted constant vector with the proper neutral element at
+/// each ssa-use. This filtering is not necessary for pure data-parallel
+/// operations.
+///
+/// In the case of vector_store/DMAs, Read-Modify-Write will be required, which
+/// also has concurrency implications. Note that by using clipped scalar stores
+/// in the presence of data-parallel only operations, we generate code that
+/// writes the same value multiple times to the edge locations.
+///
+/// TODO(ntv): implement alternatives to clipping.
+/// TODO(ntv): support non-data-parallel operations.
+
+/// Performs the rewrite.
+template <>
+PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  using namespace mlir::edsc;
+  using namespace mlir::edsc::op;
+  using namespace mlir::edsc::intrinsics;
+  using IndexedValue =
+      TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+  TransferReadOp transfer = cast<TransferReadOp>(op);
+
+  // 1. Setup all the captures.
+  ScopedContext scope(rewriter, transfer.getLoc());
+  IndexedValue remote(transfer.memref());
+  MemRefView view(transfer.memref());
+  VectorView vectorView(transfer.vector());
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+  SmallVector<ValueHandle *, 8> pivs =
+      makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+  coalesceCopy(transfer, &pivs, &vectorView);
+
+  auto lbs = vectorView.getLbs();
+  auto ubs = vectorView.getUbs();
+  SmallVector<ValueHandle, 8> steps;
+  steps.reserve(vectorView.getSteps().size());
+  for (auto step : vectorView.getSteps())
+    steps.push_back(constant_index(step));
+
+  // 2. Emit alloc-copy-load-dealloc.
+  ValueHandle tmp = alloc(tmpMemRefType(transfer));
+  IndexedValue local(tmp);
+  ValueHandle vec = vector_type_cast(tmp);
+  LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
+    // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
+    local(ivs) = remote(clip(transfer, view, ivs));
+  });
+  ValueHandle vectorValue = std_load(vec);
+  (dealloc(tmp)); // vexing parse
+
+  // 3. Propagate.
+  rewriter.replaceOp(op, vectorValue.getValue());
+  return matchSuccess();
+}
+
+/// Lowers TransferWriteOp into a combination of:
+///   1. local memory allocation;
+///   2. vector_store to local buffer (viewed as a memref<1 x vector>);
+///   3. perfect loop nest over:
+///      a. scalar load from local buffers (viewed as a scalar memref);
+///      b. scalar store to original memref (with clipping);
+///   4. local memory deallocation.
+///
+/// More specifically, lowers the data transfer part while ensuring no
+/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
+/// clipping. This means that a given memory location can be written to
+/// multiple times and concurrently.
+///
+/// See `Important notes about clipping and "full-tiles only" abstraction` in
+/// the description of the TransferReadOp lowering above.
+///
+/// TODO(ntv): implement alternatives to clipping.
+/// TODO(ntv): support non-data-parallel operations.
+template <>
+PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  using namespace mlir::edsc;
+  using namespace mlir::edsc::op;
+  using namespace mlir::edsc::intrinsics;
+  using IndexedValue =
+      TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
+
+  TransferWriteOp transfer = cast<TransferWriteOp>(op);
+
+  // 1. Setup all the captures.
+  ScopedContext scope(rewriter, transfer.getLoc());
+  IndexedValue remote(transfer.memref());
+  MemRefView view(transfer.memref());
+  ValueHandle vectorValue(transfer.vector());
+  VectorView vectorView(transfer.vector());
+  SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
+  SmallVector<ValueHandle *, 8> pivs =
+      makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+  coalesceCopy(transfer, &pivs, &vectorView);
+
+  auto lbs = vectorView.getLbs();
+  auto ubs = vectorView.getUbs();
+  SmallVector<ValueHandle, 8> steps;
+  steps.reserve(vectorView.getSteps().size());
+  for (auto step : vectorView.getSteps())
+    steps.push_back(constant_index(step));
+
+  // 2. Emit alloc-store-copy-dealloc.
+  ValueHandle tmp = alloc(tmpMemRefType(transfer));
+  IndexedValue local(tmp);
+  ValueHandle vec = vector_type_cast(tmp);
+  std_store(vectorValue, vec);
+  LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
+    // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
+    remote(clip(transfer, view, ivs)) = local(ivs);
+  });
+  (dealloc(tmp)); // vexing parse...
+
+  rewriter.eraseOp(op);
+  return matchSuccess();
+}
+
+} // namespace
+
+void mlir::populateVectorToAffineLoopsConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
+                  VectorTransferRewriter<vector::TransferWriteOp>>(context);
+}
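+
+// A minimal usage sketch (assuming a FunctionPass and the greedy pattern
+// rewrite driver available at this revision; the pass name is illustrative
+// only):
+//
+//   namespace {
+//   struct LowerVectorTransfersPass
+//       : public FunctionPass<LowerVectorTransfersPass> {
+//     void runOnFunction() override {
+//       OwningRewritePatternList patterns;
+//       populateVectorToAffineLoopsConversionPatterns(&getContext(), patterns);
+//       applyPatternsGreedily(getFunction(), patterns);
+//     }
+//   };
+//   } // namespace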