//===- AffineToStandard.cpp - Lower affine constructs to primitives -------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file lowers affine constructs (If and For statements, AffineApply // operations) within a function into their standard If and For equivalent ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" using namespace mlir; namespace { // Visit affine expressions recursively and build the sequence of operations // that correspond to it. Visitation functions return an Value of the // expression subtree they visited or `nullptr` on error. class AffineApplyExpander : public AffineExprVisitor { public: // This internal class expects arguments to be non-null, checks must be // performed at the call site. AffineApplyExpander(OpBuilder &builder, ArrayRef dimValues, ArrayRef symbolValues, Location loc) : builder(builder), dimValues(dimValues), symbolValues(symbolValues), loc(loc) {} template Value buildBinaryExpr(AffineBinaryOpExpr expr) { auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) return nullptr; auto op = builder.create(loc, lhs, rhs); return op.getResult(); } Value visitAddExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } Value visitMulExpr(AffineBinaryOpExpr expr) { return buildBinaryExpr(expr); } // Euclidean modulo operation: negative RHS is not allowed. // Remainder of the euclidean integer division is always non-negative. // // Implemented as // // a mod b = // let remainder = srem a, b; // negative = a < 0 in // select negative, remainder + b, remainder. Value visitModExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError( loc, "semi-affine expressions (modulo by non-const) are not supported"); return nullptr; } if (rhsConst.getValue() <= 0) { emitError(loc, "modulo by non-positive value is not supported"); return nullptr; } auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); Value remainder = builder.create(loc, lhs, rhs); Value zeroCst = builder.create(loc, 0); Value isRemainderNegative = builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); Value correctedRemainder = builder.create(loc, remainder, rhs); Value result = builder.create(loc, isRemainderNegative, correctedRemainder, remainder); return result; } // Floor division operation (rounds towards negative infinity). // // For positive divisors, it can be implemented without branching and with a // single division operation as // // a floordiv b = // let negative = a < 0 in // let absolute = negative ? -a - 1 : a in // let quotient = absolute / b in // negative ? -quotient - 1 : quotient Value visitFloorDivExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError( loc, "semi-affine expressions (division by non-const) are not supported"); return nullptr; } if (rhsConst.getValue() <= 0) { emitError(loc, "division by non-positive value is not supported"); return nullptr; } auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); Value zeroCst = builder.create(loc, 0); Value noneCst = builder.create(loc, -1); Value negative = builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); Value negatedDecremented = builder.create(loc, noneCst, lhs); Value dividend = builder.create(loc, negative, negatedDecremented, lhs); Value quotient = builder.create(loc, dividend, rhs); Value correctedQuotient = builder.create(loc, noneCst, quotient); Value result = builder.create(loc, negative, correctedQuotient, quotient); return result; } // Ceiling division operation (rounds towards positive infinity). // // For positive divisors, it can be implemented without branching and with a // single division operation as // // a ceildiv b = // let negative = a <= 0 in // let absolute = negative ? -a : a - 1 in // let quotient = absolute / b in // negative ? -quotient : quotient + 1 Value visitCeilDivExpr(AffineBinaryOpExpr expr) { auto rhsConst = expr.getRHS().dyn_cast(); if (!rhsConst) { emitError(loc) << "semi-affine expressions (division by non-const) are " "not supported"; return nullptr; } if (rhsConst.getValue() <= 0) { emitError(loc, "division by non-positive value is not supported"); return nullptr; } auto lhs = visit(expr.getLHS()); auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); Value zeroCst = builder.create(loc, 0); Value oneCst = builder.create(loc, 1); Value nonPositive = builder.create(loc, CmpIPredicate::sle, lhs, zeroCst); Value negated = builder.create(loc, zeroCst, lhs); Value decremented = builder.create(loc, lhs, oneCst); Value dividend = builder.create(loc, nonPositive, negated, decremented); Value quotient = builder.create(loc, dividend, rhs); Value negatedQuotient = builder.create(loc, zeroCst, quotient); Value incrementedQuotient = builder.create(loc, quotient, oneCst); Value result = builder.create(loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { auto valueAttr = builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); auto op = builder.create(loc, builder.getIndexType(), valueAttr); return op.getResult(); } Value visitDimExpr(AffineDimExpr expr) { assert(expr.getPosition() < dimValues.size() && "affine dim position out of range"); return dimValues[expr.getPosition()]; } Value visitSymbolExpr(AffineSymbolExpr expr) { assert(expr.getPosition() < symbolValues.size() && "symbol dim position out of range"); return symbolValues[expr.getPosition()]; } private: OpBuilder &builder; ArrayRef dimValues; ArrayRef symbolValues; Location loc; }; } // namespace // Create a sequence of operations that implement the `expr` applied to the // given dimension and symbol values. mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ArrayRef dimValues, ArrayRef symbolValues) { return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); } // Create a sequence of operations that implement the `affineMap` applied to // the given `operands` (as it it were an AffineApplyOp). Optional> static expandAffineMap( OpBuilder &builder, Location loc, AffineMap affineMap, ArrayRef operands) { auto numDims = affineMap.getNumDims(); auto expanded = functional::map( [numDims, &builder, loc, operands](AffineExpr expr) { return expandAffineExpr(builder, loc, expr, operands.take_front(numDims), operands.drop_front(numDims)); }, affineMap.getResults()); if (llvm::all_of(expanded, [](Value v) { return v; })) return expanded; return None; } // Given a range of values, emit the code that reduces them with "min" or "max" // depending on the provided comparison predicate. The predicate defines which // comparison to perform, "lt" for "min", "gt" for "max" and is used for the // `cmpi` operation followed by the `select` operation: // // %cond = cmpi "predicate" %v0, %v1 // %result = select %cond, %v0, %v1 // // Multiple values are scanned in a linear sequence. This creates a data // dependences that wouldn't exist in a tree reduction, but is easier to // recognize as a reduction by the subsequent passes. static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, ArrayRef values, OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); value = builder.create(loc, cmpOp.getResult(), value, *valueIt); } return value; } // Emit instructions that correspond to the affine map in the lower bound // applied to the respective operands, and compute the maximum value across // the results. Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { SmallVector boundOperands(op.getLowerBoundOperands()); auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(), boundOperands); if (!lbValues) return nullptr; return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues, builder); } // Emit instructions that correspond to the affine map in the upper bound // applied to the respective operands, and compute the minimum value across // the results. Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { SmallVector boundOperands(op.getUpperBoundOperands()); auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(), boundOperands); if (!ubValues) return nullptr; return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues, builder); } namespace { // Affine terminators are removed. class AffineTerminatorLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineTerminatorOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op); return matchSuccess(); } }; class AffineForLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineForOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); Value step = rewriter.create(loc, op.getStep()); auto f = rewriter.create(loc, lowerBound, upperBound, step); f.region().getBlocks().clear(); rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); rewriter.eraseOp(op); return matchSuccess(); } }; class AffineIfLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineIfOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); // Now we just have to handle the condition logic. auto integerSet = op.getIntegerSet(); Value zeroConstant = rewriter.create(loc, 0); SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); // Calculate cond as a conjunction without short-circuiting. Value cond = nullptr; for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) { AffineExpr constraintExpr = integerSet.getConstraint(i); bool isEquality = integerSet.isEq(i); // Build and apply an affine expression auto numDims = integerSet.getNumDims(); Value affResult = expandAffineExpr(rewriter, loc, constraintExpr, operandsRef.take_front(numDims), operandsRef.drop_front(numDims)); if (!affResult) return matchFailure(); auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; Value cmpVal = rewriter.create(loc, pred, affResult, zeroConstant); cond = cond ? rewriter.create(loc, cond, cmpVal).getResult() : cmpVal; } cond = cond ? cond : rewriter.create(loc, /*value=*/1, /*width=*/1); bool hasElseRegion = !op.elseRegion().empty(); auto ifOp = rewriter.create(loc, cond, hasElseRegion); rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back()); ifOp.thenRegion().back().erase(); if (hasElseRegion) { rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back()); ifOp.elseRegion().back().erase(); } // Ok, we're done! rewriter.eraseOp(op); return matchSuccess(); } }; // Convert an "affine.apply" operation into a sequence of arithmetic // operations using the StandardOps dialect. class AffineApplyLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineApplyOp op, PatternRewriter &rewriter) const override { auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), llvm::to_vector<8>(op.getOperands())); if (!maybeExpandedMap) return matchFailure(); rewriter.replaceOp(op, *maybeExpandedMap); return matchSuccess(); } }; // Apply the affine map from an 'affine.load' operation to its operands, and // feed the results to a newly created 'std.load' operation (which replaces the // original 'affine.load'). class AffineLoadLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) return matchFailure(); // Build std.load memref[expandedMap.results]. rewriter.replaceOpWithNewOp(op, op.getMemRef(), *resultOperands); return matchSuccess(); } }; // Apply the affine map from an 'affine.prefetch' operation to its operands, and // feed the results to a newly created 'std.prefetch' operation (which replaces // the original 'affine.prefetch'). class AffinePrefetchLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffinePrefetchOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affinePrefetchOp'. SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) return matchFailure(); // Build std.prefetch memref[expandedMap.results]. rewriter.replaceOpWithNewOp( op, op.memref(), *resultOperands, op.isWrite(), op.localityHint().getZExtValue(), op.isDataCache()); return matchSuccess(); } }; // Apply the affine map from an 'affine.store' operation to its operands, and // feed the results to a newly created 'std.store' operation (which replaces the // original 'affine.store'). class AffineStoreLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineStoreOp'. SmallVector indices(op.getMapOperands()); auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) return matchFailure(); // Build std.store valueToStore, memref[expandedMap.results]. rewriter.replaceOpWithNewOp(op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap); return matchSuccess(); } }; // Apply the affine maps from an 'affine.dma_start' operation to each of their // respective map operands, and feed the results to a newly created // 'std.dma_start' operation (which replaces the original 'affine.dma_start'). class AffineDmaStartLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineDmaStartOp op, PatternRewriter &rewriter) const override { SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); // Expand affine map for DMA source memref. auto maybeExpandedSrcMap = expandAffineMap( rewriter, op.getLoc(), op.getSrcMap(), operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1)); if (!maybeExpandedSrcMap) return matchFailure(); // Expand affine map for DMA destination memref. auto maybeExpandedDstMap = expandAffineMap( rewriter, op.getLoc(), op.getDstMap(), operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1)); if (!maybeExpandedDstMap) return matchFailure(); // Expand affine map for DMA tag memref. auto maybeExpandedTagMap = expandAffineMap( rewriter, op.getLoc(), op.getTagMap(), operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1)); if (!maybeExpandedTagMap) return matchFailure(); // Build std.dma_start operation with affine map results. rewriter.replaceOpWithNewOp( op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(), *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(), *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride()); return matchSuccess(); } }; // Apply the affine map from an 'affine.dma_wait' operation tag memref, // and feed the results to a newly created 'std.dma_wait' operation (which // replaces the original 'affine.dma_wait'). class AffineDmaWaitLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AffineDmaWaitOp op, PatternRewriter &rewriter) const override { // Expand affine map for DMA tag memref. SmallVector indices(op.getTagIndices()); auto maybeExpandedTagMap = expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); if (!maybeExpandedTagMap) return matchFailure(); // Build std.dma_wait operation with affine map results. rewriter.replaceOpWithNewOp( op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements()); return matchSuccess(); } }; } // end namespace void mlir::populateAffineToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert< AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering, AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering, AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx); } namespace { class LowerAffinePass : public FunctionPass { void runOnFunction() override { OwningRewritePatternList patterns; populateAffineToStdConversionPatterns(patterns, &getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); if (failed(applyPartialConversion(getFunction(), target, patterns))) signalPassFailure(); } }; } // namespace /// Lowers If and For operations within a function into their lower level CFG /// equivalent blocks. std::unique_ptr> mlir::createLowerAffinePass() { return std::make_unique(); } static PassRegistration pass("lower-affine", "Lower If, For, AffineApply operations to primitive equivalents");