Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 346
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp | 600
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp | 238
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 243
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 461
5 files changed, 1888 insertions, 0 deletions
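The passes added by this commit register command-line flags (see the PassRegistration calls at the bottom of each file), so they can be exercised individually through mlir-opt. A minimal sketch of such invocations — the input file name is hypothetical; only the flag names are taken from this commit:

    mlir-opt input.mlir -linalg-fusion
    mlir-opt input.mlir -linalg-promote-subviews
    mlir-opt input.mlir -convert-linalg-to-loops
    mlir-opt input.mlir -convert-linalg-to-affine-loops

Fusion and tiling additionally accept comma-separated tile sizes through the -linalg-fusion-tile-sizes and -linalg-tile-sizes options declared in Fusion.cpp and Tiling.cpp.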
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp new file mode 100644 index 00000000000..9df7bce0879 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -0,0 +1,346 @@ +//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg dialect Fusion pass. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Dominance.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/Transforms/FoldUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "linalg-fusion" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; +using namespace mlir::linalg::intrinsics; + +using llvm::dbgs; + +/// Implements a simple high-level fusion pass of linalg library operations. +/// +/// In each block, linalg ops are processed in reverse textual order. +/// Given a linalg op `O`, fusion occurs by: +/// 1. inspecting the linalg ops that write into the views read by `O`. This +/// uses the SSA value of the views and a simple subview/slice analysis to +/// determine producer-consumer dependences; +/// 2. greedily fuse the linalg ops that produce subview +/// 3. inspect the fused ops and determine whether they have other remaining +/// LinalgOp uses. If not, then erase the original producing linalg op. +/// +/// More advanced use cases, analyses as well as profitability heuristics are +/// left for future work. + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); +static llvm::cl::list<unsigned> clTileSizes( + "linalg-fusion-tile-sizes", + llvm::cl::desc( + "Tile sizes by which to tile linalg operations during linalg fusion"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::cat(clOptionsCategory)); + +// Return a cloned version of `op` that operates on `loopRanges`, assumed to be +// a subset of the original loop ranges of `op`. +// This is achieved by applying the `loopToOperandRangesMaps` permutation maps +// to the `loopRanges` in order to obtain view ranges. +static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, + ArrayRef<SubViewOp::Range> loopRanges) { + auto maps = loopToOperandRangesMaps(op); + SmallVector<Value, 8> clonedViews; + clonedViews.reserve(op.getNumInputsAndOutputs()); + // Iterate over the inputs and outputs in order. + // Extract the subranges from the linearized ranges. 
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs()); + for (auto en : llvm::enumerate(ios)) { + unsigned idx = en.index(); + auto map = maps[idx]; + LLVM_DEBUG(dbgs() << "map: " << map << "\n"); + Value view = en.value(); + SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults()); + for (auto en2 : llvm::enumerate(map.getResults())) { + unsigned d = en2.index(); + // loopToOperandRangesMaps are permutations-only. + unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); + viewRanges[d] = loopRanges[loopPos]; + LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() + << "\t" + << "loopPos: " << loopPos << "\t" << viewRanges[d]); + } + // Construct a new subview for the tile. + unsigned rank = viewRanges.size(); + SmallVector<Value, 4> offsets, sizes, strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (auto r : viewRanges) { + offsets.push_back(r.offset); + sizes.push_back(r.size); + strides.push_back(r.stride); + } + clonedViews.push_back( + b.create<SubViewOp>(loc, view, offsets, sizes, strides)); + } + auto operands = getAssumedNonViewOperands(op); + clonedViews.append(operands.begin(), operands.end()); + return op.clone(b, loc, clonedViews); +} + +struct ViewDimension { + Value view; + unsigned dimension; +}; + +// Given an `op`, returns the first (`view`, `dimension`) pair that identifies +// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps +// guarantees at least one such dimension is found. If multiple candidates exist +// they must agree by construction (i.e. have the same size) and we just return +// the first one. +static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { + auto maps = loopToOperandRangesMaps(op); + // Iterate over the inputs and outputs in order. + // Extract the subranges from the linearized ranges. + SmallVector<Value, 8> ios(op.getInputsAndOutputs()); + for (auto en : llvm::enumerate(ios)) { + unsigned idx = en.index(); + auto map = maps[idx]; + LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); + LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); + Value view = en.value(); + SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr); + for (auto en2 : llvm::enumerate(map.getResults())) { + if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { + LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth + << "\n"); + LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view + << "\n"); + return ViewDimension{view, static_cast<unsigned>(en2.index())}; + } + } + } + llvm_unreachable("Expect to be able to extract a view defining loop range"); +} + +static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, + unsigned consumerIdx, unsigned producerIdx, + OperationFolder *folder) { + auto subView = dyn_cast_or_null<SubViewOp>( + consumer.getInput(consumerIdx)->getDefiningOp()); + auto slice = dyn_cast_or_null<SliceOp>( + consumer.getInput(consumerIdx)->getDefiningOp()); + assert(subView || slice); + (void)subView; + (void)slice; + + // loopToOperandRangesMaps are permutations-only by construction: + // we can always identify a data dimension with a (at least one) loop + // dimension. 
+ AffineMap producerMap = + loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx]; + LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); + + unsigned nPar = producer.getNumParallelLoops(); + unsigned nRed = producer.getNumReductionLoops(); + unsigned nWin = producer.getNumWindowLoops(); + SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); + + // Iterate over dimensions identified by the producer map for `producerIdx`. + // This defines a subset of the loop ranges that we need to complete later. + for (auto en : llvm::enumerate(producerMap.getResults())) { + unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); + loopRanges[posInProducerLoop] = subView.getRanges()[en.index()]; + } + + OpBuilder b(consumer.getOperation()); + auto loc = consumer.getLoc(); + // Iterate over all dimensions. For the dimensions not identified by the + // producer map for `producerIdx`, we need to explicitly compute the view that + // defines the loop ranges using the `producer`. + for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { + if (loopRanges[i].offset) + LLVM_DEBUG(llvm::dbgs() + << "existing LoopRange: " << loopRanges[i] << "\n"); + else { + auto viewDim = getViewDefiningLoopRange(producer, i); + loopRanges[i] = SubViewOp::Range{constant_index(folder, 0), + dim(viewDim.view, viewDim.dimension), + constant_index(folder, 1)}; + LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); + } + } + + return cloneWithLoopRanges(b, loc, producer, loopRanges); +} + +// Encode structural fusion safety preconditions. +// Some of these will be lifted in the future with better analysis. +static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, + LinalgOp consumer) { + if (producer.getNumOutputs() != 1) { + LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); + return false; + } + // Only fuse when the producer block dominates. + DominanceInfo dom(producer.getOperation()); + if (!dom.dominates(producer.getOperation()->getBlock(), + consumer.getOperation()->getBlock())) { + LLVM_DEBUG( + dbgs() + << "\nNot structurally fusable (producer block does not dominate)"); + return false; + } + return true; +} + +bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, + Value consumedView, + LinalgOp producer) { + // Make some simple structural checks that alleviate the need for more + // complex analyses. + if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { + LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" + << *producer.getOperation()); + return false; + } + // Check for any interleaved write to consumedView. + if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" + << *producer.getOperation()); + return false; + } + return true; +} + +bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value consumedView, + LinalgOp producer) { + if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) + return false; + // Check for any fusion-preventing dependence to any view read/written that + // would violate dependences. 
+ if (!graph.findCoveringDependences(producer, consumer).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" + << *producer.getOperation()); + return false; + } + return true; +} + +// Only consider RAW atm. +Optional<FusionInfo> mlir::linalg::fuseProducerOf( + OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &graph, OperationFolder *folder) { + LLVM_DEBUG(dbgs() << "\nStart examining consumer: " + << *consumer.getOperation()); + for (auto dependence : graph.getDependencesInto( + consumer, LinalgDependenceGraph::DependenceType::RAW)) { + LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" + << *dependence.dependentOpView.op << "\n"); + auto producer = cast<LinalgOp>(dependence.dependentOpView.op); + + // Check that the dependence is indeed on the input `consumerIdx` view. + auto consumedView = dependence.indexingView; + if (consumer.getInput(consumerIdx) != consumedView) + continue; + + // Consumer consumes this view, `isStructurallyFusableProducer` also checks + // whether it is a strict subview of the producer view. + auto producedView = dependence.dependentOpView.view; + auto producerIdx = producer.getIndexOfOutput(producedView).getValue(); + // `consumerIdx` and `producerIdx` exist by construction. + LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() + << " view: " << *producedView + << " output index: " << producerIdx); + + // Must be a subview or a slice to guarantee there are loops we can fuse + // into. + auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp()); + auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp()); + if (!subView && !slice) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); + continue; + } + + // Simple fusability checks. + if (!isFusableInto(graph, consumer, consumedView, producer)) + continue; + + // Fuse `producer` just before `consumer`. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(consumer.getOperation()); + ScopedContext scope(b, consumer.getLoc()); + LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); + auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx, + producerIdx, folder); + + return FusionInfo{producer, fusedProducer}; + } + return llvm::None; +} + +static void fuseLinalgOpsGreedily(FuncOp f) { + LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); + + OpBuilder b(f); + OperationFolder folder(f.getContext()); + DenseSet<Operation *> eraseSet; + + // Save original Linalg ops, we only want to make a pass over those. + SmallVector<Operation *, 8> linalgOps; + f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); + + Aliases aliases; + LinalgDependenceGraph G(aliases, linalgOps); + for (auto *op : llvm::reverse(linalgOps)) { + for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs(); + consumerIdx < e; ++consumerIdx) { + if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder)) + eraseSet.insert(fusionInfo->originalProducer.getOperation()); + } + } + + // The `fuseProducerOf` function performs structural checks and in particular + // that no covering read or write exist between the consumer and the producer. + // As a consequence, the only fusions that may occur preserve subsequent + // dependences and are guaranteed by construction to produce the whole view. + // We may thus erase the producer once it is fused. 
+ for (auto *e : eraseSet) + e->erase(); + LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); +} + +namespace { +struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { + void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); } +}; +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgFusionPass() { + return std::make_unique<LinalgFusionPass>(); +} + +static PassRegistration<LinalgFusionPass> + pass("linalg-fusion", "Fuse operations in the linalg dialect"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp new file mode 100644 index 00000000000..d7cc4a86d21 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -0,0 +1,600 @@ +//===- LowerToLoops.cpp - conversion from Linalg library ops to loops------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/FoldUtils.h" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; +using namespace mlir::linalg::intrinsics; + +using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>; +using IndexedAffineValue = TemplatedIndexedValue<affine_load, affine_store>; + +using edsc::op::operator+; +using edsc::op::operator==; + +static SmallVector<ValueHandle, 8> +makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, + ArrayRef<Value> vals) { + assert(map.getNumSymbols() == 0); + assert(map.getNumInputs() == vals.size()); + SmallVector<ValueHandle, 8> res; + res.reserve(map.getNumResults()); + auto dims = map.getNumDims(); + for (auto e : map.getResults()) { + auto exprMap = AffineMap::get(dims, 0, e); + SmallVector<Value, 4> operands(vals.begin(), vals.end()); + canonicalizeMapAndOperands(&exprMap, &operands); + res.push_back(affine_apply(exprMap, operands)); + } + return res; +} + +static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs, + Optional<AffineMap> permutation) { + return permutation ? applyMapToValues(ScopedContext::getBuilder(), + ScopedContext::getLocation(), + permutation.getValue(), ivs) + : SmallVector<Value, 4>(ivs.begin(), ivs.end()); +} + +// Creates a number of ranges equal to the number of results in `map`. +// The returned ranges correspond to the loop ranges, in the proper order, for +// which new loops will be created. 
+static SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef<Value> allViewSizes); +SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, + ArrayRef<Value> allViewSizes) { + // Apply `map` to get view sizes in loop order. + auto sizes = applyMapToValues(b, loc, map, allViewSizes); + // Create a new range with the applied tile sizes. + ScopedContext scope(b, loc); + SmallVector<Value, 4> res; + for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { + res.push_back(range(constant_index(0), sizes[idx], constant_index(1))); + } + return res; +} + +template <typename IndexedValueType, typename LinalgOpType> +class LinalgScopedEmitter {}; + +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, CopyOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) { + auto nPar = copyOp.getNumParallelLoops(); + assert(nPar == allIvs.size()); + auto inputIvs = + permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); + auto outputIvs = + permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); + SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end()); + SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end()); + IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); + // Emit the proper scalar assignment, whether we are dealing with a 0-D or + // an n-D loop nest; with or without permutations. + // clang-format off + nPar > 0 ? O(oivs) = I(iivs) : + O() = I(); + // clang-format on + } +}; + +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, FillOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) { + auto nPar = fillOp.getNumParallelLoops(); + assert(nPar == allIvs.size()); + auto ivs = + SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar); + IndexedValueType O(fillOp.getOutput(0)); + // Emit the proper scalar assignment, whether we are dealing with a 0-D or + // an n-D loop nest; with or without permutations. + nPar > 0 ? O(ivs) = ValueHandle(fillOp.value()) + : O() = ValueHandle(fillOp.value()); + } +}; + +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, DotOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) { + assert(allIvs.size() == 1); + IndexHandle r_i(allIvs[0]); + IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), + C(dotOp.getOutput(0)); + // Emit scalar form. + C() = C() + A(r_i) * B(r_i); + } +}; + +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, MatvecOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, + MatvecOp matvecOp) { + assert(allIvs.size() == 2); + IndexHandle i(allIvs[0]), r_j(allIvs[1]); + IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), + C(matvecOp.getOutput(0)); + // Emit scalar form. + C(i) = C(i) + A(i, r_j) * B(r_j); + } +}; + +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, MatmulOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, + MatmulOp matmulOp) { + assert(allIvs.size() == 3); + IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); + IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), + C(matmulOp.getOutput(0)); + // Emit scalar form. 
+ C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); + } +}; + +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, ConvOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { + auto b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + auto maps = loopToOperandRangesMaps(convOp); + SmallVector<ValueHandle, 8> fIdx( + makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); + SmallVector<ValueHandle, 8> imIdx( + makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); + SmallVector<ValueHandle, 8> oIdx( + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); + IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output()); + // Emit scalar form. + O(oIdx) += F(fIdx) * I(imIdx); + } +}; + +// Emits the MLIR for the scalar part of the generic op by: +// 1. Emitting std_load and std_store ops for each input and output +// view in order. This is achieved by applying the appropriate input or +// output map to the enclosing induction variables. +// 2. Emitting a call to `op.fun()` that takes as arguments the scalars +// from point 1. above. +// 3. Emitting std_store to store the results of 2. to the output +// views. +// +// An example output may resemble: +// +// ``` +// loop.for %i = %c0 to %0 step %c1 { +// loop.for %j = %c0 to %1 step %c1 { +// loop.for %k = %c0 to %4 step %c1 { +// %11 = load %arg0[%i, %j] : +// memref<?x?xf32, stride_specification> +// %12 = load %arg1[%i, %j, %k] : +// memref<?x?x?xf32, stride_specification> +// %13 = load %arg2[%i, %k, %j] : +// memref<?x?x?xf32, stride_specification> +// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32) +// store %14#0, %arg1[%i, %j, %k] : +// memref<?x?x?Xf32, stride_specification> +// store %14#1, %arg2[%i, %k, %j] : +// memref<?x?x?Xf32, stride_specification> +// } +// } +// } +// ``` +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, GenericOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, + GenericOp genericOp) { + auto b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + using edsc::intrinsics::detail::ValueHandleArray; + unsigned nInputs = genericOp.getNumInputs(); + unsigned nOutputs = genericOp.getNumOutputs(); + SmallVector<Value, 4> indexedValues(nInputs + nOutputs); + + // 1.a. Emit std_load from input views. + for (unsigned i = 0; i < nInputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs)); + indexedValues[i] = std_load(genericOp.getInput(i), indexing); + } + + // 1.b. Emit std_load from output views. + for (unsigned i = 0; i < nOutputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); + } + + auto funcOp = genericOp.getFunction(); + if (funcOp) { + // 2. Emit call. + Operation *callOp = call(funcOp, indexedValues); + assert(callOp->getNumResults() == genericOp.getNumOutputs()); + + // 3. Emit std_store. + for (unsigned i = 0; i < nOutputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + } + return; + } + // TODO(ntv): When a region inliner exists, use it. + // 2. Inline region, currently only works for a single basic block. 
+ BlockAndValueMapping map; + auto &block = genericOp.region().front(); + for (auto it : llvm::zip(block.getArguments(), indexedValues)) + map.map(std::get<0>(it), std::get<1>(it)); + for (auto &op : block.without_terminator()) { + assert(op.getNumRegions() == 0); + auto *newOp = b.clone(op, map); + for (auto it : llvm::zip(op.getResults(), newOp->getResults())) + map.map(std::get<0>(it), std::get<1>(it)); + } + + // 3. Emit std_store. + auto *yieldOp = cast<YieldOp>(block.back()).getOperation(); + assert(yieldOp->getNumOperands() == nOutputs); + for (unsigned i = 0; i < nOutputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), + indexing); + } + } +}; + +// Emits the MLIR for the scalar part of the indexed generic op by: +// 1. Emitting std_load and std_store ops for each input and output view in +// order. This is achieved by applying the appropriate input or output map +// to the enclosing induction variables. +// 2. Emitting a call to `op.fun()` that takes as arguments the induction +// variables and the scalars from point 1. above. +// 3. Emitting std_store to store the results of 2. to the output views. +// +// An example output may resemble: +// +// ``` +// loop.for %i = %c0 to %0 step %c1 { +// loop.for %j = %c0 to %1 step %c1 { +// loop.for %k = %c0 to %4 step %c1 { +// %11 = load %arg0[%i, %j] : +// memref<?x?xf32, stride_specification> +// %12 = load %arg1[%i, %j, %k] : +// memref<?x?x?xf32, stride_specification> +// %13 = load %arg2[%i, %k, %j] : +// memref<?x?x?xf32, stride_specification> +// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) : +// (index, index, index, f32, f32, f32) -> (f32, f32) +// store %14#0, %arg1[%i, %j, %k] : +// memref<?x?x?Xf32, stride_specification> +// store %14#1, %arg2[%i, %k, %j] : +// memref<?x?x?Xf32, stride_specification> +// } +// } +// } +// ``` +template <typename IndexedValueType> +class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> { +public: + static void emitScalarImplementation(ArrayRef<Value> allIvs, + IndexedGenericOp indexedGenericOp) { + auto b = ScopedContext::getBuilder(); + auto loc = ScopedContext::getLocation(); + using edsc::intrinsics::detail::ValueHandleArray; + unsigned nInputs = indexedGenericOp.getNumInputs(); + unsigned nOutputs = indexedGenericOp.getNumOutputs(); + unsigned nLoops = allIvs.size(); + SmallVector<Value, 4> indexedValues(nLoops + nInputs + nOutputs); + + for (unsigned i = 0; i < nLoops; ++i) { + indexedValues[i] = allIvs[i]; + } + + // 1.a. Emit std_load from input views. + for (unsigned i = 0; i < nInputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); + indexedValues[nLoops + i] = + std_load(indexedGenericOp.getInput(i), indexing); + } + + // 1.b. Emit std_load from output views. + for (unsigned i = 0; i < nOutputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + indexedValues[nLoops + nInputs + i] = + std_load(indexedGenericOp.getOutput(i), indexing); + } + + if (auto funcOp = indexedGenericOp.getFunction()) { + // 2. Emit call. + Operation *callOp = call(funcOp, indexedValues); + assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs()); + + // 3. Emit std_store. 
+ for (unsigned i = 0; i < nOutputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), + indexing); + } + return; + } + // TODO(ntv): When a region inliner exists, use it. + // 2. Inline region, currently only works for a single basic block. + BlockAndValueMapping map; + auto &block = indexedGenericOp.region().front(); + for (auto it : llvm::zip(block.getArguments(), indexedValues)) + map.map(std::get<0>(it), std::get<1>(it)); + for (auto &op : block.without_terminator()) { + assert(op.getNumRegions() == 0); + auto *newOp = b.clone(op, map); + for (auto it : llvm::zip(op.getResults(), newOp->getResults())) + map.map(std::get<0>(it), std::get<1>(it)); + } + + // 3. Emit std_store. + auto *yieldOp = cast<YieldOp>(block.back()).getOperation(); + assert(yieldOp->getNumOperands() == nOutputs); + for (unsigned i = 0; i < nOutputs; ++i) { + ValueHandleArray indexing(makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + std_store(map.lookup(yieldOp->getOperand(i)), + indexedGenericOp.getOutput(i), indexing); + } + } +}; + +namespace { +// This struct is for factoring out the implementation and support template +// instantiations in the following 2 cases: +// 1. Appending to a list of patterns via RewritePatternList. +// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. +// The implementation must work both in DRR and inside a RewritePattern. As a +// consequence, (1) it is only allowed to emit new ops if the match is +// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an +// encompassing pattern must take care of the erasure logic. +template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy> +class LinalgOpToLoopsImpl { +public: + static LogicalResult doit(Operation *op, PatternRewriter &rewriter); +}; +} // namespace + +template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy> +LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit( + Operation *op, PatternRewriter &rewriter) { + OpBuilder b(op); + ScopedContext scope(b, op->getLoc()); + + // The flattened loopToOperandRangesMaps is expected to be an invertible + // permutation map (which is asserted in the inverse calculation). 
+ auto linalgOp = cast<ConcreteOpTy>(op); + auto invertedMap = + inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); + if (!invertedMap) { + LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation( + {}, linalgOp); + return success(); + } + + auto nPar = linalgOp.getNumParallelLoops(); + auto nRed = linalgOp.getNumReductionLoops(); + auto nWin = linalgOp.getNumWindowLoops(); + SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin); + SmallVector<ValueHandle *, 4> allPIvs = + makeHandlePointers(MutableArrayRef<IndexHandle>(allIvs)); + auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(), + invertedMap, getViewSizes(linalgOp)); + assert(loopRanges.size() == allIvs.size()); + + LoopNestRangeBuilder(allPIvs, loopRanges)([&] { + auto allIvValues = extractValues(allIvs); + LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation( + allIvValues, linalgOp); + }); + return success(); +} + +template <typename LoopType, typename IndexedValueType, typename ConcreteOp> +class LinalgRewritePattern : public RewritePattern { +public: + explicit LinalgRewritePattern(MLIRContext *context) + : RewritePattern(ConcreteOp::getOperationName(), 1, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>; + if (failed(Impl::doit(op, rewriter))) + return matchFailure(); + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + +// Helper classes for type list expansion. +template <typename LoopType, typename IndexedValueType, typename... LinalgOps> +class RewritePatternList; + +template <typename LoopType, typename IndexedValueType> +class RewritePatternList<LoopType, IndexedValueType> { +public: + static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {} +}; + +template <typename LoopType, typename IndexedValueType, typename ConcreteOp, + typename... LinalgOps> +class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> { +public: + static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns + .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>( + ctx); + RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build( + patterns, ctx); + } +}; + +/// Populate the given list with patterns that convert from Linalg to LLVM. +template <typename LoopType, typename IndexedValueType> +void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { + RewritePatternList<LoopType, IndexedValueType, +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::build(patterns, ctx); +} + +namespace { +template <typename LoopType, typename IndexedValueType> +struct LowerLinalgToLoopsPass + : public FunctionPass<LowerLinalgToLoopsPass<LoopType, IndexedValueType>> { + void runOnFunction() override; +}; +} // namespace + +// Local folding pattern for AffineApplyOp that we can apply greedily. +// This replaces AffineApplyOp by the proper value in cases where the associated +// map is trivial. A trivial map here is defined as a map with a single result +// and either: +// 1. Zero operand + returns a single AffineConstantExpr +// 2. One operand + returns a single AffineDimExpr +// 3. One operands + returns a single AffineSymbolExpr +// +// In the first case, the AffineApplyOp is replaced by a new constant. In the +// other cases, it is replaced by its unique operand. 
+struct FoldAffineOp : public RewritePattern { + FoldAffineOp(MLIRContext *context) + : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op); + auto map = affineApplyOp.getAffineMap(); + if (map.getNumResults() != 1 || map.getNumInputs() > 1) + return matchFailure(); + + AffineExpr expr = map.getResult(0); + if (map.getNumInputs() == 0) { + if (auto val = expr.dyn_cast<AffineConstantExpr>()) { + rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue()); + return matchSuccess(); + } + return matchFailure(); + } + if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) { + rewriter.replaceOp(op, op->getOperand(0)); + return matchSuccess(); + } + return matchFailure(); + } +}; + +template <typename LoopType, typename IndexedValueType> +void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() { + auto *context = &this->getContext(); + OwningRewritePatternList patterns; + // Canonicalization and folding patterns applied greedily allow cleaning up + // the emitted IR on the fly. + // TODO(ntv) fold view and subview ops? + FillRewritePatterns<LoopType, IndexedValueType>(patterns, context); + DimOp::getCanonicalizationPatterns(patterns, context); + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + patterns.insert<FoldAffineOp>(context); + // Just apply the patterns greedily. + applyPatternsGreedily(this->getFunction(), patterns); +} + +/// Create a pass to convert Linalg operations to loop.for loops and +/// std.load/std.store accesses. +std::unique_ptr<OpPassBase<FuncOp>> +mlir::linalg::createConvertLinalgToLoopsPass() { + return std::make_unique< + LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>(); +} + +/// Create a pass to convert Linalg operations to affine.for loops and +/// affine_load/affine_store accesses. +/// Placeholder for now, this is NYI. +std::unique_ptr<OpPassBase<FuncOp>> +mlir::linalg::createConvertLinalgToAffineLoopsPass() { + return std::make_unique< + LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>(); +} + +// Emits a loop nest of `loop.for` with the proper body for `op`. +template <typename ConcreteOp> +LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl<loop::ForOp, IndexedStdValue, ConcreteOp>::doit( + op, rewriter); +} + +// Emits a loop nest of `affine.for` with the proper body for `op`. +template <typename ConcreteOp> +LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, + Operation *op) { + return LinalgOpToLoopsImpl<AffineForOp, IndexedAffineValue, ConcreteOp>::doit( + op, rewriter); +} + +// TODO(ntv) Need to make these instantiations more future-proof to avoid the +// need to update as soon as we add new ops. 
+#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ + template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \ + PatternRewriter & rewriter, Operation * op); \ + template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \ + PatternRewriter & rewriter, Operation * op); + +INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp) +INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp) + +static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>> + structuredLoopsPass( + "convert-linalg-to-loops", + "Lower the operations from the linalg dialect into loops"); + +static PassRegistration<LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>> + affineLoopsPass( + "convert-linalg-to-affine-loops", + "Lower the operations from the linalg dialect into affine loops"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp new file mode 100644 index 00000000000..eb23a8ceb1a --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -0,0 +1,238 @@ +//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for transforming Linalg operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/VectorOps/VectorOps.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <type_traits> + +#define DEBUG_TYPE "linalg-transforms" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; +using namespace mlir::linalg::intrinsics; + +using llvm::dbgs; +using llvm::SetVector; + +// Marker used as attribute name in generated Linalg rewriting transformations. 
+const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = + "__internal_linalg_transform__"; + +LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes, + StringRef linalgMarker, ArrayRef<unsigned> permutation) { + assert(permutation.empty() || permutation.size() == sizes.size()); + auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation); + if (!tileRes) + return failure(); + tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + return success(); +} + +LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes, + ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, op->getParentOfType<FuncOp>()); + SmallVector<Operation *, 4> originalProducers; + for (auto operandIdx : operandIndicesToFuse) { + auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G); + if (!fusionRes) { + // Linalg fusion requires tiled loops to even determine whether it is + // possible to fuse. As a consequence, the pattern may fail even though a + // tiled version of op has already been introduced. + // So we need to remove the tiled version ourselves in case of failure. + // Another possibility is to ensure the constraints on the pattern + // guarantee that fusion will occur and just assert here. As we develop + // more complex patterns we can choose what is best. + rewriter.eraseOp(tileRes->loops[0]); + return failure(); + } + fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + originalProducers.push_back(fusionRes->originalProducer); + } + + // The originalProducers can now be safely erased. This is similar to + // SSA-value use-def but in the world of buffer + structured ops. 
+ for (auto *originalProducer : originalProducers) + rewriter.eraseOp(originalProducer); + return success(); +} + +bool mlir::linalg::detail::isProducedByOpOfTypeImpl( + Operation *consumerOp, Value consumedView, + function_ref<bool(Operation *)> isaOpType) { + LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp); + if (!consumer) + return false; + + auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); + if (!maybeConsumerIndex) + return false; + + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, consumer.getParentOfType<FuncOp>()); + for (auto dependence : G.getDependencesInto( + consumer, LinalgDependenceGraph::DependenceType::RAW)) { + auto producer = cast<LinalgOp>(dependence.dependentOpView.op); + if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) + continue; + if (isaOpType(dependence.dependentOpView.op)) + return true; + } + return false; +} + +static bool hasMultiplyAddBody(linalg::GenericOp op) { + auto &r = op.region(); + if (r.empty()) + return false; + if (r.getBlocks().size() != 1) + return false; + auto &ops = r.front().getOperations(); + if (ops.size() != 3) + return false; + + using mlir::matchers::m_Val; + auto a = m_Val(r.front().getArgument(0)); + auto b = m_Val(r.front().getArgument(1)); + auto c = m_Val(r.front().getArgument(2)); + // TODO(ntv) Update this detection once we have matcher support for + // specifying that any permutation of operands matches. + auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c)); + auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b))); + auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c)); + auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a))); + return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) || + pattern3.match(&ops.back()) || pattern4.match(&ops.back()); +} + +// TODO(ntv) should be Tablegen'd from a single source that generates the op +// itself. +static bool isMatmul(linalg::GenericOp genericOp) { + auto *ctx = genericOp.getContext(); + auto m = getAffineDimExpr(0, ctx); + auto n = getAffineDimExpr(1, ctx); + auto k = getAffineDimExpr(2, ctx); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n})); + auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx); + return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && + genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); +} + +LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter, + Operation *op) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg op as vector.contract: " + << *op << ":\n"); + + // TODO(ntv): This is in fact much more general than just vectorization for + // matmul ops. + auto genericOp = dyn_cast<linalg::GenericOp>(op); + if (!genericOp || !isMatmul(genericOp)) + return failure(); + + // TODO(ntv): non-identity layout. 
+ auto isStaticMemRefWithIdentityLayout = [](Value v) { + auto m = v->getType().dyn_cast<MemRefType>(); + if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) + return false; + return true; + }; + if (!llvm::all_of(genericOp.getInputsAndOutputs(), + isStaticMemRefWithIdentityLayout)) + return failure(); + + edsc::ScopedContext scope(rewriter, op->getLoc()); + using edsc::intrinsics::std_load; + using edsc::intrinsics::std_store; + using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>; + using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>; + auto vA = std_load(vector_type_cast(genericOp.getInput(0))); + auto vB = std_load(vector_type_cast(genericOp.getInput(1))); + auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0)); + auto vC = std_load(vectorMemRefC); + auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(), + genericOp.iterator_types()); + std_store(vRes, vectorMemRefC); + return success(); +} + +LogicalResult +mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, + ArrayRef<unsigned> permutation, + StringRef linalgMarker) { + // If permutation is empty, there is nothing to be done. + if (permutation.empty()) + return failure(); + + auto linOp = cast<LinalgOp>(op); + auto permutationMap = inversePermutation( + AffineMap::getPermutationMap(permutation, rewriter.getContext())); + SmallVector<AffineMap, 4> newIndexingMap; + auto indexingMaps = linOp.indexing_maps().getValue(); + for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) { + AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose( + permutationMap); + newIndexingMap.push_back(m); + } + auto itTypes = linOp.iterator_types().getValue(); + SmallVector<Attribute, 4> itTypesVector; + for (unsigned i = 0, e = itTypes.size(); i != e; ++i) + itTypesVector.push_back(itTypes[i]); + applyPermutationToVector(itTypesVector, permutation); + op->setAttr(getIndexingMapsAttrName(), + rewriter.getAffineMapArrayAttr(newIndexingMap)); + op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector)); + op->setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); + return success(); +} + +LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, + Operation *op) { + LinalgOp linOp = dyn_cast<LinalgOp>(op); + SetVector<Value> subViews; + for (auto it : linOp.getInputsAndOutputs()) + if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + subViews.insert(sv); + if (!subViews.empty()) { + auto resOp = promoteSubViewOperands(rewriter, linOp, subViews); + return success(resOp); + } + return failure(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp new file mode 100644 index 00000000000..b8b27958ff5 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -0,0 +1,243 @@ +//===- Promotion.cpp - Implementation of linalg Promotion -----------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg dialect Promotion pass. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/Transforms/FoldUtils.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; +using namespace mlir::linalg::intrinsics; +using namespace mlir::loop; + +using llvm::SetVector; + +#define DEBUG_TYPE "linalg-promotion" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); +static llvm::cl::opt<bool> clPromoteDynamic( + "test-linalg-promote-dynamic", + llvm::cl::desc("Test generation of dynamic promoted buffers"), + llvm::cl::cat(clOptionsCategory), llvm::cl::init(false)); + +static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) { + auto *ctx = size->getContext(); + auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); + if (!dynamicBuffers) + if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp())) + return alloc( + MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); + Value mul = muli(constant_index(width), size); + return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); +} + +// Performs promotion of a `subView` into a local buffer of the size of the +// *ranges* of the `subView`. This produces a buffer whose size may be bigger +// than the actual size of the `subView` at the boundaries. +// This is related to the full/partial tile problem. +// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and +// `partialLocalView` such that: +// * `buffer` is always the size of the full tile. +// * `fullLocalView` is a dense contiguous view into that buffer. +// * `partialLocalView` is a dense non-contiguous slice of `fullLocalView` +// that corresponds to the size of `subView` and accounting for boundary +// effects. +// The point of the full tile buffer is that constant static tile sizes are +// folded and result in a buffer type with statically known size and alignment +// properties. +// To account for general boundary effects, padding must be performed on the +// boundary tiles. For now this is done with an unconditional `fill` op followed +// by a partial `copy` op. 
+static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, + SubViewOp subView, + bool dynamicBuffers, + OperationFolder *folder) { + auto zero = constant_index(folder, 0); + auto one = constant_index(folder, 1); + + auto viewType = subView.getType(); + auto rank = viewType.getRank(); + Value allocSize = one; + SmallVector<Value, 8> fullRanges, partialRanges; + fullRanges.reserve(rank); + partialRanges.reserve(rank); + for (auto en : llvm::enumerate(subView.getRanges())) { + auto rank = en.index(); + auto rangeValue = en.value(); + Value d = rangeValue.size; + allocSize = muli(folder, allocSize, d).getValue(); + fullRanges.push_back(d); + partialRanges.push_back(range(folder, zero, dim(subView, rank), one)); + } + SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1); + auto buffer = + allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers); + auto fullLocalView = view( + MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges); + auto partialLocalView = slice(fullLocalView, partialRanges); + return PromotionInfo{buffer, fullLocalView, partialLocalView}; +} + +SmallVector<PromotionInfo, 8> +mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, + ArrayRef<Value> subViews, bool dynamicBuffers, + OperationFolder *folder) { + if (subViews.empty()) + return {}; + + ScopedContext scope(b, loc); + SmallVector<PromotionInfo, 8> res; + res.reserve(subViews.size()); + DenseMap<Value, PromotionInfo> promotionInfoMap; + for (auto v : subViews) { + SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); + auto viewType = subView.getType(); + // TODO(ntv): support more cases than just float. + if (!viewType.getElementType().isa<FloatType>()) + continue; + auto promotionInfo = + promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder); + promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo)); + res.push_back(promotionInfo); + } + + for (auto v : subViews) { + SubViewOp subView = cast<SubViewOp>(v->getDefiningOp()); + auto info = promotionInfoMap.find(v); + if (info == promotionInfoMap.end()) + continue; + // TODO(ntv): value to fill with should be related to the operation. + // For now, just use APFloat(0.0f). + auto t = subView.getType().getElementType().cast<FloatType>(); + Value fillVal = constant_float(folder, APFloat(0.0f), t); + // TODO(ntv): fill is only necessary if `promotionInfo` has a full local + // view that is different from the partial local view and we are on the + // boundary. + fill(info->second.fullLocalView, fillVal); + } + + for (auto v : subViews) { + auto info = promotionInfoMap.find(v); + if (info == promotionInfoMap.end()) + continue; + copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView); + } + return res; +} + +LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, + SetVector<Value> subViews, + bool dynamicBuffers, + OperationFolder *folder) { + // 1. Promote the specified views and use them in the new op. 
+ ScopedContext scope(b, op.getLoc()); + auto promotedBufferAndViews = promoteSubViews( + b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); + SmallVector<Value, 8> opViews; + opViews.reserve(op.getNumInputsAndOutputs()); + SmallVector<std::pair<Value, Value>, 8> writebackViews; + writebackViews.reserve(subViews.size()); + unsigned promotedIdx = 0; + for (auto view : op.getInputsAndOutputs()) { + if (subViews.count(view) != 0) { + opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); + writebackViews.emplace_back(std::make_pair( + view, promotedBufferAndViews[promotedIdx].partialLocalView)); + promotedIdx++; + } else { + opViews.push_back(view); + } + } + + // 2. Append all other operands as they appear, this enforces that such + // operands are not views. This is to support cases such as FillOp taking + // extra scalars etc. + auto operands = getAssumedNonViewOperands(op); + opViews.append(operands.begin(), operands.end()); + LinalgOp res = op.clone(b, op.getLoc(), opViews); + + // 3. Emit write-back for the promoted output views: copy the partial view. + for (auto viewAndPartialLocalView : writebackViews) { + // WARNING: MUST use the old op to determine whether the operand view is an + // output. + bool isOutput = + op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); + if (isOutput) + copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first); + } + + // 4. Dealloc local buffers. + for (const auto &pi : promotedBufferAndViews) + dealloc(pi.buffer); + + return res; +} + +static void promoteSubViews(FuncOp f, bool dynamicBuffers) { + SmallVector<LinalgOp, 8> toErase; + OperationFolder folder(f.getContext()); + f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { + // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or + // nothing. + SetVector<Value> subViews; + OpBuilder b(op); + for (auto it : op.getInputsAndOutputs()) + if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + subViews.insert(sv); + if (!subViews.empty()) { + promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder); + toErase.push_back(op); + } + }); + for (auto op : toErase) + op.erase(); +} + +namespace { +struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> { + LinalgPromotionPass() = default; + LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {} + + void runOnFunction() override { + promoteSubViews(getFunction(), dynamicBuffers); + } + + bool dynamicBuffers; +}; +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> +mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) { + return std::make_unique<LinalgPromotionPass>(dynamicBuffers); +} + +static PassRegistration<LinalgPromotionPass> + pass("linalg-promote-subviews", "promote subview ops to local buffers", [] { + return std::make_unique<LinalgPromotionPass>(clPromoteDynamic); + }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp new file mode 100644 index 00000000000..964f540c099 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -0,0 +1,461 @@ +//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg dialect Tiling pass. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/Transforms/FoldUtils.h" + +#include "llvm/Support/CommandLine.h" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; +using namespace mlir::linalg::intrinsics; +using namespace mlir::loop; + +#define DEBUG_TYPE "linalg-tiling" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); +static llvm::cl::list<unsigned> + clTileSizes("linalg-tile-sizes", + llvm::cl::desc("Tile sizes by which to tile linalg operations"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::cat(clOptionsCategory)); + +static bool isZero(Value v) { + return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) && + cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0; +} + +using LoopIndexToRangeIndexMap = DenseMap<int, int>; + +// Creates a number of ranges equal to the number of non-zero in `tileSizes`. +// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has +// one entry per surrounding loop. It uses zero as the convention that a +// particular loop is not tiled. This convention simplifies implementations by +// avoiding affine map manipulations. +// The returned ranges correspond to the loop ranges, in the proper order, that +// are tiled and for which new loops will be created. Also the function returns +// a map from loop indices of the LinalgOp to the corresponding non-empty range +// indices of newly created loops. +static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap> +makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, + ArrayRef<Value> allViewSizes, ArrayRef<Value> allTileSizes, + OperationFolder *folder) { + assert(allTileSizes.size() == map.getNumResults()); + // Apply `map` to get view sizes in loop order. + auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder); + SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end()); + + // Traverse the tile sizes, which are in loop order, erase zeros everywhere. + LoopIndexToRangeIndexMap loopIndexToRangeIndex; + for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) { + if (isZero(tileSizes[idx - zerosCount])) { + viewSizes.erase(viewSizes.begin() + idx - zerosCount); + tileSizes.erase(tileSizes.begin() + idx - zerosCount); + ++zerosCount; + continue; + } + loopIndexToRangeIndex[idx] = idx - zerosCount; + } + + // Create a new range with the applied tile sizes. 
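+  // Worked example (illustrative): with three loops and allTileSizes =
+  // [10, 0, 25], loop 1 is untiled, so its entries are erased from both
+  // vectors and loopIndexToRangeIndex = {0 -> 0, 2 -> 1}. Each surviving
+  // entry then becomes the range {0, viewSizes[idx], tileSizes[idx]}, i.e.
+  // the loop step equals the tile size.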
+ SmallVector<SubViewOp::Range, 4> res; + for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { + res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx], + tileSizes[idx]}); + } + return std::make_tuple(res, loopIndexToRangeIndex); +} + +namespace { + +// Helper visitor to determine whether an AffineExpr is tiled. +// This is achieved by traversing every AffineDimExpr with position `pos` and +// checking whether the corresponding `tileSizes[pos]` is non-zero. +// This also enforces only positive coefficients occur in multiplications. +// +// Example: +// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] +// +struct TileCheck : public AffineExprVisitor<TileCheck> { + TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {} + + void visitDimExpr(AffineDimExpr expr) { + isTiled |= !isZero(tileSizes[expr.getPosition()]); + } + void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { + visit(expr.getLHS()); + visit(expr.getRHS()); + if (expr.getKind() == mlir::AffineExprKind::Mul) + assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 && + "nonpositive multiplying coefficient"); + } + bool isTiled; + ArrayRef<Value> tileSizes; +}; + +} // namespace + +// IndexedGenericOp explicitly uses induction variables in the loop body. The +// values of the indices that are used in the loop body for any given access of +// input/output memref before `subview` op was applied should be invariant with +// respect to tiling. +// +// Therefore, if the operation is tiled, we have to transform the indices +// accordingly, i.e. offset them by the values of the corresponding induction +// variables that are captured implicitly in the body of the op. +// +// Example. `linalg.indexed_generic` before tiling: +// +// #id_2d = (i, j) -> (i, j) +// #pointwise_2d_trait = { +// indexing_maps = [#id_2d, #id_2d], +// iterator_types = ["parallel", "parallel"], +// n_views = [1, 1] +// } +// linalg.indexed_generic #pointwise_2d_trait %operand, %result { +// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32): +// <some operations that use %i, %j> +// }: memref<50x100xf32>, memref<50x100xf32> +// +// After tiling pass with tiles sizes 10 and 25: +// +// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2) +// +// %c1 = constant 1 : index +// %c0 = constant 0 : index +// %c25 = constant 25 : index +// %c10 = constant 10 : index +// operand_dim_0 = dim %operand, 0 : memref<50x100xf32> +// operand_dim_1 = dim %operand, 1 : memref<50x100xf32> +// loop.for %k = %c0 to operand_dim_0 step %c10 { +// loop.for %l = %c0 to operand_dim_1 step %c25 { +// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1] +// : memref<50x100xf32> to memref<?x?xf32, #strided> +// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1] +// : memref<50x100xf32> to memref<?x?xf32, #strided> +// linalg.indexed_generic pointwise_2d_trait %4, %5 { +// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32): +// // Indices `k` and `l` are implicitly captured in the body. +// %transformed_i = addi %i, %k : index // index `i` is offset by %k +// %transformed_j = addi %j, %l : index // index `j` is offset by %l +// // Every use of %i, %j is replaced with %transformed_i, %transformed_j +// <some operations that use %transformed_i, %transformed_j> +// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided> +// } +// } +// +// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices +// does not lead to losing information. 
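+// The helper below implements the rewrite sketched above: for every tiled
+// loop, it inserts an addi at the top of the indexed_generic body that offsets
+// the corresponding index block argument by the enclosing induction variable,
+// and redirects all other uses of that argument to the offset value.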
+void transformIndexedGenericOpIndices( + OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs, + const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { + auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation()); + if (!indexedGenericOp) + return; + + // `linalg.indexed_generic` comes in two flavours. One has a region with a + // single block that defines the loop body. The other has a `fun` attribute + // that refers to an existing function symbol. The `fun` function call will be + // inserted in the loop body in that case. + // + // TODO(pifon): Add support for `linalg.indexed_generic` with `fun` attribute. + auto ®ion = indexedGenericOp.region(); + if (region.empty()) { + indexedGenericOp.emitError("op expected a region"); + return; + } + auto &block = region.getBlocks().front(); + + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(&block); + for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) { + auto rangeIndex = loopIndexToRangeIndex.find(i); + if (rangeIndex == loopIndexToRangeIndex.end()) + continue; + Value oldIndex = block.getArgument(i); + // Offset the index argument `i` by the value of the corresponding induction + // variable and replace all uses of the previous value. + Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, + pivs[rangeIndex->second]->getValue()); + for (auto &use : oldIndex->getUses()) { + if (use.getOwner() == newIndex->getDefiningOp()) + continue; + use.set(newIndex); + } + } +} + +static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) { + if (!expr) + return false; + TileCheck t(tileSizes); + t.visit(expr); + return t.isTiled; +} + +// Checks whether the view with index `viewIndex` within `linalgOp` varies with +// respect to a non-zero `tileSize`. +static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) { + if (!map) + return false; + for (unsigned r = 0; r < map.getNumResults(); ++r) + if (isTiled(map.getResult(r), tileSizes)) + return true; + return false; +} + +static SmallVector<Value, 4> +makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, + ArrayRef<Value> ivs, ArrayRef<Value> tileSizes, + ArrayRef<Value> viewSizes, OperationFolder *folder) { + assert(ivs.size() == static_cast<size_t>(llvm::count_if( + llvm::make_range(tileSizes.begin(), tileSizes.end()), + [](Value v) { return !isZero(v); })) && + "expected as many ivs as non-zero sizes"); + + using edsc::intrinsics::select; + using edsc::op::operator+; + using edsc::op::operator<; + + // Construct (potentially temporary) mins and maxes on which to apply maps + // that define tile subviews. + SmallVector<Value, 8> lbs, subViewSizes; + for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { + bool isTiled = !isZero(tileSizes[idx]); + lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)constant_index(folder, 0)); + subViewSizes.push_back(isTiled ? tileSizes[idx] : viewSizes[idx]); + } + + auto *op = linalgOp.getOperation(); + + SmallVector<Value, 4> res; + res.reserve(op->getNumOperands()); + auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin(); + for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); + ++viewIndex) { + Value view = *(viewIteratorBegin + viewIndex); + unsigned rank = view->getType().cast<MemRefType>().getRank(); + auto map = loopToOperandRangesMaps(linalgOp)[viewIndex]; + // If the view is not tiled, we can use it as is. + if (!isTiled(map, tileSizes)) { + res.push_back(view); + continue; + } + + // Construct a new subview for the tile. 
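+    // Illustrative example (assuming a matmul-like view whose indexing map is
+    // (i, j, k) -> (i, k) and all three tile sizes are non-zero): for r = 0
+    // the sub-map (i, j, k) -> (i) yields offset = %iv_i, size = %ts_i,
+    // stride = 1; for r = 1 the sub-map (i, j, k) -> (k) yields offset =
+    // %iv_k, size = %ts_k, stride = 1.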
+ SmallVector<Value, 4> offsets, sizes, strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned r = 0; r < rank; ++r) { + if (!isTiled(map.getSubMap({r}), tileSizes)) { + offsets.push_back(constant_index(folder, 0)); + sizes.push_back(dim(view, r)); + strides.push_back(constant_index(folder, 1)); + continue; + } + + // Tiling creates a new slice at the proper index, the slice step is 1 + // (i.e. the slice view does not subsample, stepping occurs in the loop). + auto m = map.getSubMap({r}); + auto offset = applyMapToValues(b, loc, m, lbs, folder).front(); + offsets.push_back(offset); + auto size = applyMapToValues(b, loc, m, subViewSizes, folder).front(); + sizes.push_back(size); + strides.push_back(constant_index(folder, 1)); + } + // TODO(b/144419024) Atm std.subview is not guaranteed in-bounds. Depending + // on the semantics we attach to it, we may need to use min(size, dim) here + // and canonicalize later. + res.push_back(b.create<SubViewOp>(loc, view, offsets, sizes, strides)); + } + + // Traverse the mins/maxes and erase those that don't have uses left. + // This is a special type of folding that we only apply when `folder` is + // defined. + if (folder) + for (auto v : llvm::concat<Value>(lbs, subViewSizes)) + if (v->use_empty()) + v->getDefiningOp()->erase(); + + return res; +} + +Optional<TiledLinalgOp> +mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes, + ArrayRef<unsigned> permutation, + OperationFolder *folder) { + // 1. Enforce the convention that "tiling by zero" skips tiling a particular + // dimension. This convention is significantly simpler to handle instead of + // adjusting affine maps to account for missing dimensions. + assert(op.getNumParallelLoops() + op.getNumReductionLoops() + + op.getNumWindowLoops() == + tileSizes.size() && + "expected matching number of tile sizes and loops"); + + // If permutation is empty, use the identity. Build the permutation map + // otherwise. + auto invPermutationMap = AffineMap::getMultiDimIdentityMap( + tileSizes.size(), ScopedContext::getContext()); + if (!permutation.empty()) + invPermutationMap = inversePermutation( + AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); + + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + ScopedContext scope(b, op.getLoc()); + // 2. Build the tiled loop ranges. + auto viewSizes = getViewSizes(op); + // The flattened loopToOperandRangesMaps is expected to be an invertible + // permutation map (asserted in the inverse calculation). + auto viewSizesToLoopsMap = + inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op))); + assert(viewSizesToLoopsMap && "expected invertible map"); + + SmallVector<SubViewOp::Range, 4> loopRanges; + LoopIndexToRangeIndexMap loopIndexToRangeIndex; + std::tie(loopRanges, loopIndexToRangeIndex) = + makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap, + viewSizes, tileSizes, folder); + if (!permutation.empty()) + applyPermutationToVector(loopRanges, permutation); + + // 3. Create the tiled loops. 
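+  // The emitted structure is roughly the following (illustrative; names
+  // invented, one loop.for per non-zero tile size, in loop-range order after
+  // the optional permutation):
+  //   loop.for %i0 = %c0 to %dim0 step %ts0 {
+  //     loop.for %i1 = %c0 to %dim1 step %ts1 {
+  //       %sv0 = subview ... ; %sv1 = subview ...
+  //       <cloned linalg op operating on the tiled subviews>
+  //     }
+  //   }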
+  LinalgOp res = op;
+  SmallVector<IndexHandle, 4> ivs(loopRanges.size());
+  auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+  LoopNestRangeBuilder(pivs, loopRanges)([&] {
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
+
+    // If we have to apply a permutation to the tiled loop nest, we have to
+    // reorder the induction variables. This permutation is the right one
+    // assuming that loopRanges have previously been permuted by
+    // (i,j,k)->(k,i,j), so it should be the inversePermutation of that one:
+    // (d0,d1,d2)->(d2,d0,d1).
+    if (!permutation.empty())
+      ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
+
+    auto views =
+        makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
+    auto operands = getAssumedNonViewOperands(op);
+    views.append(operands.begin(), operands.end());
+    res = op.clone(b, loc, views);
+  });
+
+  // 4. Transform the index arguments of `linalg.indexed_generic` w.r.t. the
+  // tiling.
+  transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
+
+  // 5. Gather the newly created loops and return them with the new op.
+  SmallVector<ForOp, 8> loops;
+  loops.reserve(ivs.size());
+  for (auto iv : ivs)
+    loops.push_back(loop::getForInductionVarOwner(iv));
+
+  return TiledLinalgOp{res, loops};
+}
+
+Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
+    OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
+    ArrayRef<unsigned> permutation, OperationFolder *folder) {
+  if (tileSizes.empty())
+    return llvm::None;
+
+  // The following uses the convention that "tiling by zero" skips tiling a
+  // particular dimension. This convention is significantly simpler to handle
+  // than adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
+                op.getNumWindowLoops();
+  tileSizes = tileSizes.take_front(nLoops);
+  // If all remaining tile sizes are zero, there is nothing to tile.
+  if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
+    return llvm::None;
+
+  // Create a builder for tile size constants.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  ScopedContext scope(b, op.getLoc());
+
+  // Materialize concrete tile size values to pass to the generic tiling
+  // function.
+  SmallVector<Value, 8> tileSizeValues;
+  tileSizeValues.reserve(tileSizes.size());
+  for (auto ts : tileSizes)
+    tileSizeValues.push_back(constant_index(folder, ts));
+  // Pad tile sizes with zero values to enforce our convention.
+  if (tileSizeValues.size() < nLoops) {
+    for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
+      tileSizeValues.push_back(constant_index(folder, 0));
+  }
+
+  return tileLinalgOp(b, op, tileSizeValues, permutation, folder);
+}
+
+static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
+  OpBuilder b(f);
+  OperationFolder folder(f.getContext());
+  f.walk([tileSizes, &b, &folder](LinalgOp op) {
+    auto opLoopsPair =
+        tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder);
+    // If tiling occurred successfully, erase the old op.
+    if (opLoopsPair)
+      op.erase();
+  });
+  f.walk([](LinalgOp op) {
+    if (!op.getOperation()->hasNoSideEffect())
+      return;
+    if (op.getOperation()->use_empty())
+      op.erase();
+  });
+}
+
+namespace {
+struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
+  LinalgTilingPass() = default;
+  LinalgTilingPass(ArrayRef<int64_t> sizes);
+
+  void runOnFunction() override { tileLinalgOps(getFunction(), tileSizes); }
+
+  SmallVector<int64_t, 8> tileSizes;
+};
+} // namespace
+
+LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes) {
+  this->tileSizes.assign(sizes.begin(), sizes.end());
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
+  return std::make_unique<LinalgTilingPass>(tileSizes);
+}
+
+static PassRegistration<LinalgTilingPass>
+    pass("linalg-tile", "Tile operations in the linalg dialect", [] {
+      auto pass = std::make_unique<LinalgTilingPass>();
+      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+      return pass;
+    });
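+// Example invocations (illustrative; flag spellings taken from the pass
+// registrations in this file and in Promotion.cpp, typical mlir-opt usage
+// assumed):
+//   mlir-opt foo.mlir -linalg-tile -linalg-tile-sizes=2,3,4
+//   mlir-opt foo.mlir -linalg-promote-subviews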