Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms')
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp            346
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp     600
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp  238
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp         243
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp            461
5 files changed, 1888 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
new file mode 100644
index 00000000000..9df7bce0879
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -0,0 +1,346 @@
+//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the linalg dialect Fusion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-fusion"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using llvm::dbgs;
+
+/// Implements a simple high-level fusion pass of linalg library operations.
+///
+/// In each block, linalg ops are processed in reverse textual order.
+/// Given a linalg op `O`, fusion occurs by:
+/// 1. inspecting the linalg ops that write into the views read by `O`. This
+/// uses the SSA value of the views and a simple subview/slice analysis to
+/// determine producer-consumer dependences;
+/// 2. greedily fusing the linalg ops that produce the subviews read by `O`;
+/// 3. inspecting the fused ops and determining whether they have any other
+/// remaining LinalgOp uses. If not, the original producing linalg op is
+/// erased.
+///
+/// More advanced use cases, analyses, and profitability heuristics are
+/// left for future work.
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+static llvm::cl::list<unsigned> clTileSizes(
+ "linalg-fusion-tile-sizes",
+ llvm::cl::desc(
+ "Tile sizes by which to tile linalg operations during linalg fusion"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clOptionsCategory));
+
+// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
+// a subset of the original loop ranges of `op`.
+// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
+// to the `loopRanges` in order to obtain view ranges.
+static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
+ ArrayRef<SubViewOp::Range> loopRanges) {
+ auto maps = loopToOperandRangesMaps(op);
+ SmallVector<Value, 8> clonedViews;
+ clonedViews.reserve(op.getNumInputsAndOutputs());
+ // Iterate over the inputs and outputs in order.
+ // Extract the subranges from the linearized ranges.
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ for (auto en : llvm::enumerate(ios)) {
+ unsigned idx = en.index();
+ auto map = maps[idx];
+ LLVM_DEBUG(dbgs() << "map: " << map << "\n");
+ Value view = en.value();
+ SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
+ for (auto en2 : llvm::enumerate(map.getResults())) {
+ unsigned d = en2.index();
+ // loopToOperandRangesMaps are permutations-only.
+ unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
+ viewRanges[d] = loopRanges[loopPos];
+ LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
+ << "\t"
+ << "loopPos: " << loopPos << "\t" << viewRanges[d]);
+ }
+ // Construct a new subview for the tile.
+ unsigned rank = viewRanges.size();
+ SmallVector<Value, 4> offsets, sizes, strides;
+ offsets.reserve(rank);
+ sizes.reserve(rank);
+ strides.reserve(rank);
+ for (auto r : viewRanges) {
+ offsets.push_back(r.offset);
+ sizes.push_back(r.size);
+ strides.push_back(r.stride);
+ }
+ clonedViews.push_back(
+ b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+ }
+ auto operands = getAssumedNonViewOperands(op);
+ clonedViews.append(operands.begin(), operands.end());
+ return op.clone(b, loc, clonedViews);
+}
+
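To make the role of the permutation maps concrete, here is a hypothetical illustration (not from the patch) for a matmul-like op:

```
// Hypothetical illustration: for a matmul-like op whose
// loopToOperandRangesMaps are
//   A: (i, j, k) -> (i, k),  B: (i, j, k) -> (k, j),  C: (i, j, k) -> (i, j),
// calling cloneWithLoopRanges(b, loc, op, {rI, rJ, rK}) creates
//   subview(A, {rI, rK}), subview(B, {rK, rJ}), subview(C, {rI, rJ})
// and clones `op` so that it operates on these subviews.
```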
+struct ViewDimension {
+ Value view;
+ unsigned dimension;
+};
+
+// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
+// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
+// guarantees at least one such dimension is found. If multiple candidates exist
+// they must agree by construction (i.e. have the same size) and we just return
+// the first one.
+static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
+ auto maps = loopToOperandRangesMaps(op);
+ // Iterate over the inputs and outputs in order.
+ // Extract the subranges from the linearized ranges.
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ for (auto en : llvm::enumerate(ios)) {
+ unsigned idx = en.index();
+ auto map = maps[idx];
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
+ Value view = en.value();
+ SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
+ for (auto en2 : llvm::enumerate(map.getResults())) {
+ if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+ << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
+ << "\n");
+ return ViewDimension{view, static_cast<unsigned>(en2.index())};
+ }
+ }
+ }
+ llvm_unreachable("Expect to be able to extract a view defining loop range");
+}
+
+static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
+ unsigned consumerIdx, unsigned producerIdx,
+ OperationFolder *folder) {
+ auto subView = dyn_cast_or_null<SubViewOp>(
+ consumer.getInput(consumerIdx)->getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(
+ consumer.getInput(consumerIdx)->getDefiningOp());
+ assert(subView || slice);
+ (void)subView;
+ (void)slice;
+
+ // loopToOperandRangesMaps are permutations-only by construction:
+ // we can always identify a data dimension with at least one loop
+ // dimension.
+ AffineMap producerMap =
+ loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
+ LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
+ << ", producer map: " << producerMap << "\n");
+
+ unsigned nPar = producer.getNumParallelLoops();
+ unsigned nRed = producer.getNumReductionLoops();
+ unsigned nWin = producer.getNumWindowLoops();
+ SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
+
+ // Iterate over dimensions identified by the producer map for `producerIdx`.
+ // This defines a subset of the loop ranges that we need to complete later.
+ for (auto en : llvm::enumerate(producerMap.getResults())) {
+ unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+ loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
+ }
+
+ OpBuilder b(consumer.getOperation());
+ auto loc = consumer.getLoc();
+ // Iterate over all dimensions. For the dimensions not identified by the
+ // producer map for `producerIdx`, we need to explicitly compute the view that
+ // defines the loop ranges using the `producer`.
+ for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
+ if (loopRanges[i].offset)
+ LLVM_DEBUG(llvm::dbgs()
+ << "existing LoopRange: " << loopRanges[i] << "\n");
+ else {
+ auto viewDim = getViewDefiningLoopRange(producer, i);
+ loopRanges[i] = SubViewOp::Range{constant_index(folder, 0),
+ dim(viewDim.view, viewDim.dimension),
+ constant_index(folder, 1)};
+ LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
+ }
+ }
+
+ return cloneWithLoopRanges(b, loc, producer, loopRanges);
+}
+
+// Encode structural fusion safety preconditions.
+// Some of these will be lifted in the future with better analysis.
+static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
+ LinalgOp consumer) {
+ if (producer.getNumOutputs() != 1) {
+ LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
+ return false;
+ }
+ // Only fuse when the producer block dominates.
+ DominanceInfo dom(producer.getOperation());
+ if (!dom.dominates(producer.getOperation()->getBlock(),
+ consumer.getOperation()->getBlock())) {
+ LLVM_DEBUG(
+ dbgs()
+ << "\nNot structurally fusable (producer block does not dominate)");
+ return false;
+ }
+ return true;
+}
+
+bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
+ LinalgOp consumer,
+ Value consumedView,
+ LinalgOp producer) {
+ // Make some simple structural checks that alleviate the need for more
+ // complex analyses.
+ if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
+ LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ // Check for any interleaved write to consumedView.
+ if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
+ LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ return true;
+}
+
+bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
+ LinalgOp consumer, Value consumedView,
+ LinalgOp producer) {
+ if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
+ return false;
+ // Check for any interleaved dependence on any view read/written that would
+ // be violated by the fusion.
+ if (!graph.findCoveringDependences(producer, consumer).empty()) {
+ LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ return true;
+}
+
+// Only consider RAW (read-after-write) dependences for now.
+Optional<FusionInfo> mlir::linalg::fuseProducerOf(
+ OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
+ const LinalgDependenceGraph &graph, OperationFolder *folder) {
+ LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
+ << *consumer.getOperation());
+ for (auto dependence : graph.getDependencesInto(
+ consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+ LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
+ << *dependence.dependentOpView.op << "\n");
+ auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+
+ // Check that the dependence is indeed on the input `consumerIdx` view.
+ auto consumedView = dependence.indexingView;
+ if (consumer.getInput(consumerIdx) != consumedView)
+ continue;
+
+ // The consumer consumes this view; `isStructurallyFusableProducer` also
+ // checks whether it is a strict subview of the producer view.
+ auto producedView = dependence.dependentOpView.view;
+ auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
+ // `consumerIdx` and `producerIdx` exist by construction.
+ LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
+ << " view: " << *producedView
+ << " output index: " << producerIdx);
+
+ // Must be a subview or a slice to guarantee there are loops we can fuse
+ // into.
+ auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp());
+ if (!subView && !slice) {
+ LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
+ continue;
+ }
+
+ // Simple fusability checks.
+ if (!isFusableInto(graph, consumer, consumedView, producer))
+ continue;
+
+ // Fuse `producer` just before `consumer`.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(consumer.getOperation());
+ ScopedContext scope(b, consumer.getLoc());
+ LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+ auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
+ producerIdx, folder);
+
+ return FusionInfo{producer, fusedProducer};
+ }
+ return llvm::None;
+}
+
+static void fuseLinalgOpsGreedily(FuncOp f) {
+ LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
+
+ OpBuilder b(f);
+ OperationFolder folder(f.getContext());
+ DenseSet<Operation *> eraseSet;
+
+ // Save the original Linalg ops; we only want to make a pass over those.
+ SmallVector<Operation *, 8> linalgOps;
+ f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
+
+ Aliases aliases;
+ LinalgDependenceGraph G(aliases, linalgOps);
+ for (auto *op : llvm::reverse(linalgOps)) {
+ for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
+ consumerIdx < e; ++consumerIdx) {
+ if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder))
+ eraseSet.insert(fusionInfo->originalProducer.getOperation());
+ }
+ }
+
+ // The `fuseProducerOf` function performs structural checks, in particular
+ // that no covering read or write exists between the consumer and the producer.
+ // As a consequence, the only fusions that may occur preserve subsequent
+ // dependences and are guaranteed by construction to produce the whole view.
+ // We may thus erase the producer once it is fused.
+ for (auto *e : eraseSet)
+ e->erase();
+ LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
+}
+
+namespace {
+struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
+ void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
+};
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgFusionPass() {
+ return std::make_unique<LinalgFusionPass>();
+}
+
+static PassRegistration<LinalgFusionPass>
+ pass("linalg-fusion", "Fuse operations in the linalg dialect");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
new file mode 100644
index 00000000000..d7cc4a86d21
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -0,0 +1,600 @@
+//===- LinalgToLoops.cpp - Conversion from Linalg library ops to loops ----===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using IndexedStdValue = TemplatedIndexedValue<std_load, std_store>;
+using IndexedAffineValue = TemplatedIndexedValue<affine_load, affine_store>;
+
+using edsc::op::operator+;
+using edsc::op::operator==;
+
+static SmallVector<ValueHandle, 8>
+makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
+ ArrayRef<Value> vals) {
+ assert(map.getNumSymbols() == 0);
+ assert(map.getNumInputs() == vals.size());
+ SmallVector<ValueHandle, 8> res;
+ res.reserve(map.getNumResults());
+ auto dims = map.getNumDims();
+ for (auto e : map.getResults()) {
+ auto exprMap = AffineMap::get(dims, 0, e);
+ SmallVector<Value, 4> operands(vals.begin(), vals.end());
+ canonicalizeMapAndOperands(&exprMap, &operands);
+ res.push_back(affine_apply(exprMap, operands));
+ }
+ return res;
+}
+
+static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
+ Optional<AffineMap> permutation) {
+ return permutation ? applyMapToValues(ScopedContext::getBuilder(),
+ ScopedContext::getLocation(),
+ permutation.getValue(), ivs)
+ : SmallVector<Value, 4>(ivs.begin(), ivs.end());
+}
+
+// Creates a number of ranges equal to the number of results in `map`.
+// The returned ranges correspond to the loop ranges, in the proper order, for
+// which new loops will be created.
+static SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc,
+ AffineMap map,
+ ArrayRef<Value> allViewSizes);
+SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+ ArrayRef<Value> allViewSizes) {
+ // Apply `map` to get view sizes in loop order.
+ auto sizes = applyMapToValues(b, loc, map, allViewSizes);
+ // Create a new range with the applied tile sizes.
+ ScopedContext scope(b, loc);
+ SmallVector<Value, 4> res;
+ for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
+ res.push_back(range(constant_index(0), sizes[idx], constant_index(1)));
+ }
+ return res;
+}
+
+template <typename IndexedValueType, typename LinalgOpType>
+class LinalgScopedEmitter {};
+
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, CopyOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+ auto nPar = copyOp.getNumParallelLoops();
+ assert(nPar == allIvs.size());
+ auto inputIvs =
+ permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
+ auto outputIvs =
+ permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
+ SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
+ SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
+ IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0));
+ // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+ // an n-D loop nest; with or without permutations.
+ // clang-format off
+ nPar > 0 ? O(oivs) = I(iivs) :
+ O() = I();
+ // clang-format on
+ }
+};
+
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, FillOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+ auto nPar = fillOp.getNumParallelLoops();
+ assert(nPar == allIvs.size());
+ auto ivs =
+ SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
+ IndexedValueType O(fillOp.getOutput(0));
+ // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+ // an n-D loop nest; with or without permutations.
+ nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
+ : O() = ValueHandle(fillOp.value());
+ }
+};
+
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, DotOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
+ assert(allIvs.size() == 1);
+ IndexHandle r_i(allIvs[0]);
+ IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
+ C(dotOp.getOutput(0));
+ // Emit scalar form.
+ C() = C() + A(r_i) * B(r_i);
+ }
+};
+
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
+ MatvecOp matvecOp) {
+ assert(allIvs.size() == 2);
+ IndexHandle i(allIvs[0]), r_j(allIvs[1]);
+ IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
+ C(matvecOp.getOutput(0));
+ // Emit scalar form.
+ C(i) = C(i) + A(i, r_j) * B(r_j);
+ }
+};
+
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
+ MatmulOp matmulOp) {
+ assert(allIvs.size() == 3);
+ IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
+ IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
+ C(matmulOp.getOutput(0));
+ // Emit scalar form.
+ C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
+ }
+};
+
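For reference, the scalar form emitted for MatmulOp corresponds roughly to the following loop nest (names are illustrative, following the commented examples elsewhere in this file):

```
// loop.for %i = %c0 to %M step %c1 {
//   loop.for %j = %c0 to %N step %c1 {
//     loop.for %r_k = %c0 to %K step %c1 {
//       %a = load %A[%i, %r_k] : memref<?x?xf32, stride_specification>
//       %b = load %B[%r_k, %j] : memref<?x?xf32, stride_specification>
//       %c = load %C[%i, %j] : memref<?x?xf32, stride_specification>
//       %d = mulf %a, %b : f32
//       %e = addf %c, %d : f32
//       store %e, %C[%i, %j] : memref<?x?xf32, stride_specification>
//     }
//   }
// }
```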
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, ConvOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
+ auto b = ScopedContext::getBuilder();
+ auto loc = ScopedContext::getLocation();
+ auto maps = loopToOperandRangesMaps(convOp);
+ SmallVector<ValueHandle, 8> fIdx(
+ makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
+ SmallVector<ValueHandle, 8> imIdx(
+ makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
+ SmallVector<ValueHandle, 8> oIdx(
+ makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
+ IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
+ // Emit scalar form.
+ O(oIdx) += F(fIdx) * I(imIdx);
+ }
+};
+
+// Emits the MLIR for the scalar part of the generic op by:
+// 1. Emitting std_load ops for each input and output
+// view in order. This is achieved by applying the appropriate input or
+// output map to the enclosing induction variables.
+// 2. Emitting a call to `op.fun()` that takes as arguments the scalars
+// from point 1. above.
+// 3. Emitting std_store to store the results of 2. to the output
+// views.
+//
+// An example output may resemble:
+//
+// ```
+// loop.for %i = %c0 to %0 step %c1 {
+// loop.for %j = %c0 to %1 step %c1 {
+// loop.for %k = %c0 to %4 step %c1 {
+// %11 = load %arg0[%i, %j] :
+// memref<?x?xf32, stride_specification>
+// %12 = load %arg1[%i, %j, %k] :
+// memref<?x?x?xf32, stride_specification>
+// %13 = load %arg2[%i, %k, %j] :
+// memref<?x?x?xf32, stride_specification>
+// %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+// store %14#0, %arg1[%i, %j, %k] :
+// memref<?x?x?Xf32, stride_specification>
+// store %14#1, %arg2[%i, %k, %j] :
+// memref<?x?x?Xf32, stride_specification>
+// }
+// }
+// }
+// ```
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, GenericOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
+ GenericOp genericOp) {
+ auto b = ScopedContext::getBuilder();
+ auto loc = ScopedContext::getLocation();
+ using edsc::intrinsics::detail::ValueHandleArray;
+ unsigned nInputs = genericOp.getNumInputs();
+ unsigned nOutputs = genericOp.getNumOutputs();
+ SmallVector<Value, 4> indexedValues(nInputs + nOutputs);
+
+ // 1.a. Emit std_load from input views.
+ for (unsigned i = 0; i < nInputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getInputIndexingMap(i), allIvs));
+ indexedValues[i] = std_load(genericOp.getInput(i), indexing);
+ }
+
+ // 1.b. Emit std_load from output views.
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+ indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
+ }
+
+ auto funcOp = genericOp.getFunction();
+ if (funcOp) {
+ // 2. Emit call.
+ Operation *callOp = call(funcOp, indexedValues);
+ assert(callOp->getNumResults() == genericOp.getNumOutputs());
+
+ // 3. Emit std_store.
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+ std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+ }
+ return;
+ }
+ // TODO(ntv): When a region inliner exists, use it.
+ // 2. Inline region, currently only works for a single basic block.
+ BlockAndValueMapping map;
+ auto &block = genericOp.region().front();
+ for (auto it : llvm::zip(block.getArguments(), indexedValues))
+ map.map(std::get<0>(it), std::get<1>(it));
+ for (auto &op : block.without_terminator()) {
+ assert(op.getNumRegions() == 0);
+ auto *newOp = b.clone(op, map);
+ for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
+ map.map(std::get<0>(it), std::get<1>(it));
+ }
+
+ // 3. Emit std_store.
+ auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+ assert(yieldOp->getNumOperands() == nOutputs);
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+ std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
+ indexing);
+ }
+ }
+};
+
+// Emits the MLIR for the scalar part of the indexed generic op by:
+// 1. Emitting std_load ops for each input and output view in
+// order. This is achieved by applying the appropriate input or output map
+// to the enclosing induction variables.
+// 2. Emitting a call to `op.fun()` that takes as arguments the induction
+// variables and the scalars from point 1. above.
+// 3. Emitting std_store to store the results of 2. to the output views.
+//
+// An example output may resemble:
+//
+// ```
+// loop.for %i = %c0 to %0 step %c1 {
+// loop.for %j = %c0 to %1 step %c1 {
+// loop.for %k = %c0 to %4 step %c1 {
+// %11 = load %arg0[%i, %j] :
+// memref<?x?xf32, stride_specification>
+// %12 = load %arg1[%i, %j, %k] :
+// memref<?x?x?xf32, stride_specification>
+// %13 = load %arg2[%i, %k, %j] :
+// memref<?x?x?xf32, stride_specification>
+// %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+// (index, index, index, f32, f32, f32) -> (f32, f32)
+// store %14#0, %arg1[%i, %j, %k] :
+// memref<?x?x?Xf32, stride_specification>
+// store %14#1, %arg2[%i, %k, %j] :
+// memref<?x?x?Xf32, stride_specification>
+// }
+// }
+// }
+// ```
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
+public:
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
+ IndexedGenericOp indexedGenericOp) {
+ auto b = ScopedContext::getBuilder();
+ auto loc = ScopedContext::getLocation();
+ using edsc::intrinsics::detail::ValueHandleArray;
+ unsigned nInputs = indexedGenericOp.getNumInputs();
+ unsigned nOutputs = indexedGenericOp.getNumOutputs();
+ unsigned nLoops = allIvs.size();
+ SmallVector<Value, 4> indexedValues(nLoops + nInputs + nOutputs);
+
+ for (unsigned i = 0; i < nLoops; ++i) {
+ indexedValues[i] = allIvs[i];
+ }
+
+ // 1.a. Emit std_load from input views.
+ for (unsigned i = 0; i < nInputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
+ indexedValues[nLoops + i] =
+ std_load(indexedGenericOp.getInput(i), indexing);
+ }
+
+ // 1.b. Emit std_load from output views.
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+ indexedValues[nLoops + nInputs + i] =
+ std_load(indexedGenericOp.getOutput(i), indexing);
+ }
+
+ if (auto funcOp = indexedGenericOp.getFunction()) {
+ // 2. Emit call.
+ Operation *callOp = call(funcOp, indexedValues);
+ assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
+
+ // 3. Emit std_store.
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+ std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+ indexing);
+ }
+ return;
+ }
+ // TODO(ntv): When a region inliner exists, use it.
+ // 2. Inline region, currently only works for a single basic block.
+ BlockAndValueMapping map;
+ auto &block = indexedGenericOp.region().front();
+ for (auto it : llvm::zip(block.getArguments(), indexedValues))
+ map.map(std::get<0>(it), std::get<1>(it));
+ for (auto &op : block.without_terminator()) {
+ assert(op.getNumRegions() == 0);
+ auto *newOp = b.clone(op, map);
+ for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
+ map.map(std::get<0>(it), std::get<1>(it));
+ }
+
+ // 3. Emit std_store.
+ auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+ assert(yieldOp->getNumOperands() == nOutputs);
+ for (unsigned i = 0; i < nOutputs; ++i) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+ std_store(map.lookup(yieldOp->getOperand(i)),
+ indexedGenericOp.getOutput(i), indexing);
+ }
+ }
+};
+
+namespace {
+// This struct factors out the implementation and supports template
+// instantiations in the following 2 cases:
+// 1. Appending to a list of patterns via RewritePatternList.
+// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+// The implementation must work both in DRR and inside a RewritePattern. As a
+// consequence, (1) it is only allowed to emit new ops if the match is
+// guaranteed to be a success, (2) it is not allowed to erase/replace ops, and
+// (3) an encompassing pattern must take care of the erasure logic.
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+class LinalgOpToLoopsImpl {
+public:
+ static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
+};
+} // namespace
+
+template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
+ Operation *op, PatternRewriter &rewriter) {
+ OpBuilder b(op);
+ ScopedContext scope(b, op->getLoc());
+
+ // The flattened loopToOperandRangesMaps is expected to be an invertible
+ // permutation map (which is asserted in the inverse calculation).
+ auto linalgOp = cast<ConcreteOpTy>(op);
+ auto invertedMap =
+ inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
+ if (!invertedMap) {
+ LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+ {}, linalgOp);
+ return success();
+ }
+
+ auto nPar = linalgOp.getNumParallelLoops();
+ auto nRed = linalgOp.getNumReductionLoops();
+ auto nWin = linalgOp.getNumWindowLoops();
+ SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
+ SmallVector<ValueHandle *, 4> allPIvs =
+ makeHandlePointers(MutableArrayRef<IndexHandle>(allIvs));
+ auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
+ invertedMap, getViewSizes(linalgOp));
+ assert(loopRanges.size() == allIvs.size());
+
+ LoopNestRangeBuilder(allPIvs, loopRanges)([&] {
+ auto allIvValues = extractValues(allIvs);
+ LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
+ allIvValues, linalgOp);
+ });
+ return success();
+}
+
+template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
+class LinalgRewritePattern : public RewritePattern {
+public:
+ explicit LinalgRewritePattern(MLIRContext *context)
+ : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
+ if (failed(Impl::doit(op, rewriter)))
+ return matchFailure();
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+};
+
+// Helper classes for type list expansion.
+template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
+class RewritePatternList;
+
+template <typename LoopType, typename IndexedValueType>
+class RewritePatternList<LoopType, IndexedValueType> {
+public:
+ static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+};
+
+template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
+ typename... LinalgOps>
+class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
+public:
+ static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ patterns
+ .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
+ ctx);
+ RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
+ patterns, ctx);
+ }
+};
+
+/// Populate the given list with patterns that lower Linalg ops to loops.
+template <typename LoopType, typename IndexedValueType>
+void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ RewritePatternList<LoopType, IndexedValueType,
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >::build(patterns, ctx);
+}
+
+namespace {
+template <typename LoopType, typename IndexedValueType>
+struct LowerLinalgToLoopsPass
+ : public FunctionPass<LowerLinalgToLoopsPass<LoopType, IndexedValueType>> {
+ void runOnFunction() override;
+};
+} // namespace
+
+// Local folding pattern for AffineApplyOp that we can apply greedily.
+// This replaces AffineApplyOp by the proper value in cases where the associated
+// map is trivial. A trivial map here is defined as a map with a single result
+// and either:
+// 1. Zero operands + returns a single AffineConstantExpr
+// 2. One operand + returns a single AffineDimExpr
+// 3. One operand + returns a single AffineSymbolExpr
+//
+// In the first case, the AffineApplyOp is replaced by a new constant. In the
+// other cases, it is replaced by its unique operand.
+struct FoldAffineOp : public RewritePattern {
+ FoldAffineOp(MLIRContext *context)
+ : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
+ auto map = affineApplyOp.getAffineMap();
+ if (map.getNumResults() != 1 || map.getNumInputs() > 1)
+ return matchFailure();
+
+ AffineExpr expr = map.getResult(0);
+ if (map.getNumInputs() == 0) {
+ if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+ if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
+ rewriter.replaceOp(op, op->getOperand(0));
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+
+template <typename LoopType, typename IndexedValueType>
+void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
+ auto *context = &this->getContext();
+ OwningRewritePatternList patterns;
+ // Canonicalization and folding patterns applied greedily allow cleaning up
+ // the emitted IR on the fly.
+ // TODO(ntv) fold view and subview ops?
+ FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
+ DimOp::getCanonicalizationPatterns(patterns, context);
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+ patterns.insert<FoldAffineOp>(context);
+ // Just apply the patterns greedily.
+ applyPatternsGreedily(this->getFunction(), patterns);
+}
+
+/// Create a pass to convert Linalg operations to loop.for loops and
+/// std.load/std.store accesses.
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createConvertLinalgToLoopsPass() {
+ return std::make_unique<
+ LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>();
+}
+
+/// Create a pass to convert Linalg operations to affine.for loops and
+/// affine_load/affine_store accesses.
+/// Placeholder for now, this is NYI.
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createConvertLinalgToAffineLoopsPass() {
+ return std::make_unique<
+ LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>();
+}
+
+// Emits a loop nest of `loop.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+ Operation *op) {
+ return LinalgOpToLoopsImpl<loop::ForOp, IndexedStdValue, ConcreteOp>::doit(
+ op, rewriter);
+}
+
+// Emits a loop nest of `affine.for` with the proper body for `op`.
+template <typename ConcreteOp>
+LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
+ Operation *op) {
+ return LinalgOpToLoopsImpl<AffineForOp, IndexedAffineValue, ConcreteOp>::doit(
+ op, rewriter);
+}
+
+// TODO(ntv) Need to make these instantiations more future-proof to avoid the
+// need to update as soon as we add new ops.
+#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \
+ template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>( \
+ PatternRewriter & rewriter, Operation * op); \
+ template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>( \
+ PatternRewriter & rewriter, Operation * op);
+
+INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
+
+static PassRegistration<LowerLinalgToLoopsPass<loop::ForOp, IndexedStdValue>>
+ structuredLoopsPass(
+ "convert-linalg-to-loops",
+ "Lower the operations from the linalg dialect into loops");
+
+static PassRegistration<LowerLinalgToLoopsPass<AffineForOp, IndexedAffineValue>>
+ affineLoopsPass(
+ "convert-linalg-to-affine-loops",
+ "Lower the operations from the linalg dialect into affine loops");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
new file mode 100644
index 00000000000..eb23a8ceb1a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -0,0 +1,238 @@
+//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for transforming Linalg operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <type_traits>
+
+#define DEBUG_TYPE "linalg-transforms"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using llvm::dbgs;
+using llvm::SetVector;
+
+// Marker used as attribute name in generated Linalg rewriting transformations.
+const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
+ "__internal_linalg_transform__";
+
+LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+ StringRef linalgMarker, ArrayRef<unsigned> permutation) {
+ assert(permutation.empty() || permutation.size() == sizes.size());
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation);
+ if (!tileRes)
+ return failure();
+ tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ return success();
+}
+
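A hypothetical usage sketch: a pattern that tiles linalg.matmul and tags the result with the transform marker so it is not re-tiled by subsequent applications. The tile sizes, the marker string, and the pattern name are made up for illustration:

```
// Hypothetical sketch, not part of this commit.
struct TileMatmulPattern : public RewritePattern {
  explicit TileMatmulPattern(MLIRContext *ctx)
      : RewritePattern(MatmulOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    // Skip ops already processed (they carry the transform marker).
    if (op->getAttr(LinalgTransforms::kLinalgTransformMarker))
      return matchFailure();
    if (failed(tileLinalgOpAndSetMarker(rewriter, op, /*sizes=*/{8, 8, 8},
                                        /*linalgMarker=*/"tiled",
                                        /*permutation=*/{})))
      return matchFailure();
    // The tiled op replaces the original; erasure is assumed to be owned by
    // the enclosing pattern.
    rewriter.eraseOp(op);
    return matchSuccess();
  }
};
```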
+LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+ if (!tileRes)
+ return failure();
+ tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ Aliases aliases;
+ auto G = LinalgDependenceGraph::buildDependenceGraph(
+ aliases, op->getParentOfType<FuncOp>());
+ SmallVector<Operation *, 4> originalProducers;
+ for (auto operandIdx : operandIndicesToFuse) {
+ auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G);
+ if (!fusionRes) {
+ // Linalg fusion requires tiled loops to even determine whether it is
+ // possible to fuse. As a consequence, the pattern may fail even though a
+ // tiled version of op has already been introduced.
+ // So we need to remove the tiled version ourselves in case of failure.
+ // Another possibility is to ensure the constraints on the pattern
+ // guarantee that fusion will occur and just assert here. As we develop
+ // more complex patterns we can choose what is best.
+ rewriter.eraseOp(tileRes->loops[0]);
+ return failure();
+ }
+ fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ originalProducers.push_back(fusionRes->originalProducer);
+ }
+
+ // The originalProducers can now be safely erased. This is similar to
+ // SSA-value use-def but in the world of buffer + structured ops.
+ for (auto *originalProducer : originalProducers)
+ rewriter.eraseOp(originalProducer);
+ return success();
+}
+
+bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
+ Operation *consumerOp, Value consumedView,
+ function_ref<bool(Operation *)> isaOpType) {
+ LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
+ if (!consumer)
+ return false;
+
+ auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView);
+ if (!maybeConsumerIndex)
+ return false;
+
+ Aliases aliases;
+ auto G = LinalgDependenceGraph::buildDependenceGraph(
+ aliases, consumer.getParentOfType<FuncOp>());
+ for (auto dependence : G.getDependencesInto(
+ consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+ auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+ if (!isProducerLastWriteOfView(G, consumer, consumedView, producer))
+ continue;
+ if (isaOpType(dependence.dependentOpView.op))
+ return true;
+ }
+ return false;
+}
+
+static bool hasMultiplyAddBody(linalg::GenericOp op) {
+ auto &r = op.region();
+ if (r.empty())
+ return false;
+ if (r.getBlocks().size() != 1)
+ return false;
+ auto &ops = r.front().getOperations();
+ if (ops.size() != 3)
+ return false;
+
+ using mlir::matchers::m_Val;
+ auto a = m_Val(r.front().getArgument(0));
+ auto b = m_Val(r.front().getArgument(1));
+ auto c = m_Val(r.front().getArgument(2));
+ // TODO(ntv) Update this detection once we have matcher support for
+ // specifying that any permutation of operands matches.
+ auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
+ auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
+ auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
+ auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+ return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) ||
+ pattern3.match(&ops.back()) || pattern4.match(&ops.back());
+}
+
+// TODO(ntv) should be Tablegen'd from a single source that generates the op
+// itself.
+static bool isMatmul(linalg::GenericOp genericOp) {
+ auto *ctx = genericOp.getContext();
+ auto m = getAffineDimExpr(0, ctx);
+ auto n = getAffineDimExpr(1, ctx);
+ auto k = getAffineDimExpr(2, ctx);
+ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
+ auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
+ auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
+ auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx);
+ return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+ genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
+}
+
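For reference, the region shape accepted by hasMultiplyAddBody/isMatmul looks roughly as follows (operand order of the mulf/addf may be permuted; names are illustrative):

```
// ^bb0(%a: f32, %b: f32, %c: f32):
//   %0 = mulf %a, %b : f32
//   %1 = addf %0, %c : f32
//   linalg.yield %1 : f32
```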
+LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
+ Operation *op) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+ "]: Rewrite linalg op as vector.contract: "
+ << *op << ":\n");
+
+ // TODO(ntv): This is in fact much more general than just vectorization for
+ // matmul ops.
+ auto genericOp = dyn_cast<linalg::GenericOp>(op);
+ if (!genericOp || !isMatmul(genericOp))
+ return failure();
+
+ // TODO(ntv): non-identity layout.
+ auto isStaticMemRefWithIdentityLayout = [](Value v) {
+ auto m = v->getType().dyn_cast<MemRefType>();
+ if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
+ return false;
+ return true;
+ };
+ if (!llvm::all_of(genericOp.getInputsAndOutputs(),
+ isStaticMemRefWithIdentityLayout))
+ return failure();
+
+ edsc::ScopedContext scope(rewriter, op->getLoc());
+ using edsc::intrinsics::std_load;
+ using edsc::intrinsics::std_store;
+ using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
+ using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
+ auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
+ auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
+ auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
+ auto vC = std_load(vectorMemRefC);
+ auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
+ genericOp.iterator_types());
+ std_store(vRes, vectorMemRefC);
+ return success();
+}
+
+LogicalResult
+mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
+ ArrayRef<unsigned> permutation,
+ StringRef linalgMarker) {
+ // If permutation is empty, there is nothing to be done.
+ if (permutation.empty())
+ return failure();
+
+ auto linOp = cast<LinalgOp>(op);
+ auto permutationMap = inversePermutation(
+ AffineMap::getPermutationMap(permutation, rewriter.getContext()));
+ SmallVector<AffineMap, 4> newIndexingMap;
+ auto indexingMaps = linOp.indexing_maps().getValue();
+ for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
+ AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose(
+ permutationMap);
+ newIndexingMap.push_back(m);
+ }
+ auto itTypes = linOp.iterator_types().getValue();
+ SmallVector<Attribute, 4> itTypesVector;
+ for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
+ itTypesVector.push_back(itTypes[i]);
+ applyPermutationToVector(itTypesVector, permutation);
+ op->setAttr(getIndexingMapsAttrName(),
+ rewriter.getAffineMapArrayAttr(newIndexingMap));
+ op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector));
+ op->setAttr(LinalgTransforms::kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
+ return success();
+}
+
+LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
+ Operation *op) {
+ LinalgOp linOp = dyn_cast<LinalgOp>(op);
+ SetVector<Value> subViews;
+ for (auto it : linOp.getInputsAndOutputs())
+ if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
+ subViews.insert(sv);
+ if (!subViews.empty()) {
+ auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
+ return success(resOp);
+ }
+ return failure();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
new file mode 100644
index 00000000000..b8b27958ff5
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -0,0 +1,243 @@
+//===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the linalg dialect Promotion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+using namespace mlir::loop;
+
+using llvm::SetVector;
+
+#define DEBUG_TYPE "linalg-promotion"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+static llvm::cl::opt<bool> clPromoteDynamic(
+ "test-linalg-promote-dynamic",
+ llvm::cl::desc("Test generation of dynamic promoted buffers"),
+ llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
+
+static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
+ auto *ctx = size->getContext();
+ auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+ if (!dynamicBuffers)
+ if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
+ return alloc(
+ MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
+ Value mul = muli(constant_index(width), size);
+ return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
+}
+
+// Performs promotion of a `subView` into a local buffer of the size of the
+// *ranges* of the `subView`. This produces a buffer whose size may be bigger
+// than the actual size of the `subView` at the boundaries.
+// This is related to the full/partial tile problem.
+// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
+// `partialLocalView` such that:
+// * `buffer` is always the size of the full tile.
+// * `fullLocalView` is a dense contiguous view into that buffer.
+// * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
+// that corresponds to the size of `subView` and accounting for boundary
+// effects.
+// The point of the full tile buffer is that constant static tile sizes are
+// folded and result in a buffer type with statically known size and alignment
+// properties.
+// To account for general boundary effects, padding must be performed on the
+// boundary tiles. For now this is done with an unconditional `fill` op followed
+// by a partial `copy` op.
+static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
+ SubViewOp subView,
+ bool dynamicBuffers,
+ OperationFolder *folder) {
+ auto zero = constant_index(folder, 0);
+ auto one = constant_index(folder, 1);
+
+ auto viewType = subView.getType();
+ auto rank = viewType.getRank();
+ Value allocSize = one;
+ SmallVector<Value, 8> fullRanges, partialRanges;
+ fullRanges.reserve(rank);
+ partialRanges.reserve(rank);
+ for (auto en : llvm::enumerate(subView.getRanges())) {
+ auto rank = en.index();
+ auto rangeValue = en.value();
+ Value d = rangeValue.size;
+ allocSize = muli(folder, allocSize, d).getValue();
+ fullRanges.push_back(d);
+ partialRanges.push_back(range(folder, zero, dim(subView, rank), one));
+ }
+ SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
+ auto buffer =
+ allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
+ auto fullLocalView = view(
+ MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
+ auto partialLocalView = slice(fullLocalView, partialRanges);
+ return PromotionInfo{buffer, fullLocalView, partialLocalView};
+}
+
+SmallVector<PromotionInfo, 8>
+mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
+ ArrayRef<Value> subViews, bool dynamicBuffers,
+ OperationFolder *folder) {
+ if (subViews.empty())
+ return {};
+
+ ScopedContext scope(b, loc);
+ SmallVector<PromotionInfo, 8> res;
+ res.reserve(subViews.size());
+ DenseMap<Value, PromotionInfo> promotionInfoMap;
+ for (auto v : subViews) {
+ SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
+ auto viewType = subView.getType();
+ // TODO(ntv): support more cases than just float.
+ if (!viewType.getElementType().isa<FloatType>())
+ continue;
+ auto promotionInfo =
+ promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
+ promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
+ res.push_back(promotionInfo);
+ }
+
+ for (auto v : subViews) {
+ SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
+ auto info = promotionInfoMap.find(v);
+ if (info == promotionInfoMap.end())
+ continue;
+ // TODO(ntv): value to fill with should be related to the operation.
+ // For now, just use APFloat(0.0f).
+ auto t = subView.getType().getElementType().cast<FloatType>();
+ Value fillVal = constant_float(folder, APFloat(0.0f), t);
+ // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
+ // view that is different from the partial local view and we are on the
+ // boundary.
+ fill(info->second.fullLocalView, fillVal);
+ }
+
+ for (auto v : subViews) {
+ auto info = promotionInfoMap.find(v);
+ if (info == promotionInfoMap.end())
+ continue;
+ copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView);
+ }
+ return res;
+}
+
+LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
+ SetVector<Value> subViews,
+ bool dynamicBuffers,
+ OperationFolder *folder) {
+ // 1. Promote the specified views and use them in the new op.
+ ScopedContext scope(b, op.getLoc());
+ auto promotedBufferAndViews = promoteSubViews(
+ b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
+ SmallVector<Value, 8> opViews;
+ opViews.reserve(op.getNumInputsAndOutputs());
+ SmallVector<std::pair<Value, Value>, 8> writebackViews;
+ writebackViews.reserve(subViews.size());
+ unsigned promotedIdx = 0;
+ for (auto view : op.getInputsAndOutputs()) {
+ if (subViews.count(view) != 0) {
+ opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
+ writebackViews.emplace_back(std::make_pair(
+ view, promotedBufferAndViews[promotedIdx].partialLocalView));
+ promotedIdx++;
+ } else {
+ opViews.push_back(view);
+ }
+ }
+
+ // 2. Append all other operands as they appear, this enforces that such
+ // operands are not views. This is to support cases such as FillOp taking
+ // extra scalars etc.
+ auto operands = getAssumedNonViewOperands(op);
+ opViews.append(operands.begin(), operands.end());
+ LinalgOp res = op.clone(b, op.getLoc(), opViews);
+
+ // 3. Emit write-back for the promoted output views: copy the partial view.
+ for (auto viewAndPartialLocalView : writebackViews) {
+ // WARNING: MUST use the old op to determine whether the operand view is an
+ // output.
+ bool isOutput =
+ op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
+ if (isOutput)
+ copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first);
+ }
+
+ // 4. Dealloc local buffers.
+ for (const auto &pi : promotedBufferAndViews)
+ dealloc(pi.buffer);
+
+ return res;
+}
+
+static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
+ SmallVector<LinalgOp, 8> toErase;
+ OperationFolder folder(f.getContext());
+ f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
+ // TODO(ntv) add a heuristic here to decide what to promote. At the moment it
+ // is all or nothing.
+ SetVector<Value> subViews;
+ OpBuilder b(op);
+ for (auto it : op.getInputsAndOutputs())
+ if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
+ subViews.insert(sv);
+ if (!subViews.empty()) {
+ promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
+ toErase.push_back(op);
+ }
+ });
+ for (auto op : toErase)
+ op.erase();
+}
+
+namespace {
+struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
+ LinalgPromotionPass() = default;
+ LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {}
+
+ void runOnFunction() override {
+ promoteSubViews(getFunction(), dynamicBuffers);
+ }
+
+ bool dynamicBuffers;
+};
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) {
+ return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
+}
+
+static PassRegistration<LinalgPromotionPass>
+ pass("linalg-promote-subviews", "promote subview ops to local buffers", [] {
+ return std::make_unique<LinalgPromotionPass>(clPromoteDynamic);
+ });
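A hypothetical sketch of promoting the subview operands of a single LinalgOp directly, mirroring the static driver above (error handling elided; the helper name is made up):

```
// Hypothetical sketch, not part of this commit.
static void promoteOneLinalgOp(LinalgOp op, OperationFolder *folder) {
  SetVector<Value> subViews;
  for (auto v : op.getInputsAndOutputs())
    if (auto sv = dyn_cast_or_null<SubViewOp>(v->getDefiningOp()))
      subViews.insert(sv);
  if (subViews.empty())
    return;
  OpBuilder b(op);
  // Clones `op` onto full local buffers, emits the fill/copy-in and the
  // write-back copies, and deallocates the buffers.
  promoteSubViewOperands(b, op, subViews, /*dynamicBuffers=*/false, folder);
  op.erase();
}
```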
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
new file mode 100644
index 00000000000..964f540c099
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -0,0 +1,461 @@
+//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the linalg dialect Tiling pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+using namespace mlir::loop;
+
+#define DEBUG_TYPE "linalg-tiling"
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+static llvm::cl::list<unsigned>
+ clTileSizes("linalg-tile-sizes",
+ llvm::cl::desc("Tile sizes by which to tile linalg operations"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clOptionsCategory));
+
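+// Returns true if `v` is the result of a ConstantIndexOp with value zero.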
+static bool isZero(Value v) {
+ return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
+ cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
+}
+
+using LoopIndexToRangeIndexMap = DenseMap<int, int>;
+
+// Creates a number of ranges equal to the number of non-zero entries in
+// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
+// `tileSizes` argument has one entry per surrounding loop; an entry of zero
+// signals, by convention, that the corresponding loop is not tiled. This
+// convention simplifies implementations by avoiding affine map manipulations.
+// The returned ranges correspond to the loop ranges, in the proper order, that
+// are tiled and for which new loops will be created. The function also returns
+// a map from loop indices of the LinalgOp to the corresponding non-empty range
+// indices of newly created loops.
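+// For example, with three loops and `allTileSizes` = [%c10, %c0, %c25], only
+// loops 0 and 2 are tiled: two ranges are returned and the resulting map is
+// {0 -> 0, 2 -> 1}.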
+static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
+makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+ ArrayRef<Value> allViewSizes, ArrayRef<Value> allTileSizes,
+ OperationFolder *folder) {
+ assert(allTileSizes.size() == map.getNumResults());
+ // Apply `map` to get view sizes in loop order.
+ auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
+ SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
+
+  // Traverse the tile sizes, which are in loop order, and erase the zero
+  // entries.
+ LoopIndexToRangeIndexMap loopIndexToRangeIndex;
+ for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
+ if (isZero(tileSizes[idx - zerosCount])) {
+ viewSizes.erase(viewSizes.begin() + idx - zerosCount);
+ tileSizes.erase(tileSizes.begin() + idx - zerosCount);
+ ++zerosCount;
+ continue;
+ }
+ loopIndexToRangeIndex[idx] = idx - zerosCount;
+ }
+
+  // Create one range per remaining (non-zero) tile size.
+ SmallVector<SubViewOp::Range, 4> res;
+ for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
+ res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
+ tileSizes[idx]});
+ }
+ return std::make_tuple(res, loopIndexToRangeIndex);
+}
+
+namespace {
+
+// Helper visitor to determine whether an AffineExpr is tiled.
+// This is achieved by traversing every AffineDimExpr with position `pos` and
+// checking whether the corresponding `tileSizes[pos]` is non-zero.
+// This also enforces that only positive coefficients occur in multiplications.
+//
+// Example:
+// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
+//
+struct TileCheck : public AffineExprVisitor<TileCheck> {
+ TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {}
+
+ void visitDimExpr(AffineDimExpr expr) {
+ isTiled |= !isZero(tileSizes[expr.getPosition()]);
+ }
+ void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+ visit(expr.getLHS());
+ visit(expr.getRHS());
+ if (expr.getKind() == mlir::AffineExprKind::Mul)
+ assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
+ "nonpositive multiplying coefficient");
+ }
+ bool isTiled;
+ ArrayRef<Value> tileSizes;
+};
+
+} // namespace
+
+// IndexedGenericOp explicitly uses induction variables in the loop body. The
+// index values used in the loop body for any given access of an input/output
+// memref before the `subview` op was applied should be invariant with respect
+// to tiling.
+//
+// Therefore, if the operation is tiled, we have to transform the indices
+// accordingly, i.e. offset them by the values of the corresponding induction
+// variables that are captured implicitly in the body of the op.
+//
+// Example. `linalg.indexed_generic` before tiling:
+//
+// #id_2d = (i, j) -> (i, j)
+// #pointwise_2d_trait = {
+// indexing_maps = [#id_2d, #id_2d],
+// iterator_types = ["parallel", "parallel"],
+// n_views = [1, 1]
+// }
+// linalg.indexed_generic #pointwise_2d_trait %operand, %result {
+// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
+// <some operations that use %i, %j>
+// }: memref<50x100xf32>, memref<50x100xf32>
+//
+// After tiling pass with tiles sizes 10 and 25:
+//
+// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
+//
+// %c1 = constant 1 : index
+// %c0 = constant 0 : index
+// %c25 = constant 25 : index
+// %c10 = constant 10 : index
+// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
+// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
+// loop.for %k = %c0 to operand_dim_0 step %c10 {
+// loop.for %l = %c0 to operand_dim_1 step %c25 {
+// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
+// : memref<50x100xf32> to memref<?x?xf32, #strided>
+// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
+// : memref<50x100xf32> to memref<?x?xf32, #strided>
+// linalg.indexed_generic pointwise_2d_trait %4, %5 {
+// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
+// // Indices `k` and `l` are implicitly captured in the body.
+// %transformed_i = addi %i, %k : index // index `i` is offset by %k
+// %transformed_j = addi %j, %l : index // index `j` is offset by %l
+// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
+// <some operations that use %transformed_i, %transformed_j>
+// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
+// }
+// }
+//
+// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
+// can lead to loss of information.
+void transformIndexedGenericOpIndices(
+ OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
+ const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+ auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
+ if (!indexedGenericOp)
+ return;
+
+ // `linalg.indexed_generic` comes in two flavours. One has a region with a
+ // single block that defines the loop body. The other has a `fun` attribute
+ // that refers to an existing function symbol. The `fun` function call will be
+ // inserted in the loop body in that case.
+ //
+ // TODO(pifon): Add support for `linalg.indexed_generic` with `fun` attribute.
+ auto &region = indexedGenericOp.region();
+ if (region.empty()) {
+ indexedGenericOp.emitError("op expected a region");
+ return;
+ }
+ auto &block = region.getBlocks().front();
+
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(&block);
+ for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
+ auto rangeIndex = loopIndexToRangeIndex.find(i);
+ if (rangeIndex == loopIndexToRangeIndex.end())
+ continue;
+ Value oldIndex = block.getArgument(i);
+ // Offset the index argument `i` by the value of the corresponding induction
+ // variable and replace all uses of the previous value.
+ Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
+ pivs[rangeIndex->second]->getValue());
+ for (auto &use : oldIndex->getUses()) {
+ if (use.getOwner() == newIndex->getDefiningOp())
+ continue;
+ use.set(newIndex);
+ }
+ }
+}
+
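+// Returns true if `expr` involves at least one AffineDimExpr whose
+// corresponding entry in `tileSizes` is non-zero.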
+static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
+ if (!expr)
+ return false;
+ TileCheck t(tileSizes);
+ t.visit(expr);
+ return t.isTiled;
+}
+
+// Checks whether any result of the indexing `map` of a view varies with
+// respect to a non-zero entry of `tileSizes`.
+static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
+ if (!map)
+ return false;
+ for (unsigned r = 0; r < map.getNumResults(); ++r)
+ if (isTiled(map.getResult(r), tileSizes))
+ return true;
+ return false;
+}
+
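+// Creates the tiled views for the operands of `linalgOp`. Views whose indexing
+// map does not involve any tiled loop are passed through unchanged; for the
+// others, a `subview` is created whose offsets are derived from the current
+// induction variables `ivs` and whose sizes are the tile sizes (or the full
+// view size along untiled dimensions).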
+static SmallVector<Value, 4>
+makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
+ ArrayRef<Value> ivs, ArrayRef<Value> tileSizes,
+ ArrayRef<Value> viewSizes, OperationFolder *folder) {
+ assert(ivs.size() == static_cast<size_t>(llvm::count_if(
+ llvm::make_range(tileSizes.begin(), tileSizes.end()),
+ [](Value v) { return !isZero(v); })) &&
+ "expected as many ivs as non-zero sizes");
+
+ using edsc::intrinsics::select;
+ using edsc::op::operator+;
+ using edsc::op::operator<;
+
+ // Construct (potentially temporary) mins and maxes on which to apply maps
+ // that define tile subviews.
+ SmallVector<Value, 8> lbs, subViewSizes;
+ for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
+ bool isTiled = !isZero(tileSizes[idx]);
+ lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)constant_index(folder, 0));
+ subViewSizes.push_back(isTiled ? tileSizes[idx] : viewSizes[idx]);
+ }
+
+ auto *op = linalgOp.getOperation();
+
+ SmallVector<Value, 4> res;
+ res.reserve(op->getNumOperands());
+ auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
+ for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
+ ++viewIndex) {
+ Value view = *(viewIteratorBegin + viewIndex);
+ unsigned rank = view->getType().cast<MemRefType>().getRank();
+ auto map = loopToOperandRangesMaps(linalgOp)[viewIndex];
+ // If the view is not tiled, we can use it as is.
+ if (!isTiled(map, tileSizes)) {
+ res.push_back(view);
+ continue;
+ }
+
+ // Construct a new subview for the tile.
+ SmallVector<Value, 4> offsets, sizes, strides;
+ offsets.reserve(rank);
+ sizes.reserve(rank);
+ strides.reserve(rank);
+ for (unsigned r = 0; r < rank; ++r) {
+ if (!isTiled(map.getSubMap({r}), tileSizes)) {
+ offsets.push_back(constant_index(folder, 0));
+ sizes.push_back(dim(view, r));
+ strides.push_back(constant_index(folder, 1));
+ continue;
+ }
+
+      // Tiling creates a new slice at the proper index; the slice step is 1
+      // (i.e. the slice view does not subsample; stepping occurs in the loop).
+ auto m = map.getSubMap({r});
+ auto offset = applyMapToValues(b, loc, m, lbs, folder).front();
+ offsets.push_back(offset);
+ auto size = applyMapToValues(b, loc, m, subViewSizes, folder).front();
+ sizes.push_back(size);
+ strides.push_back(constant_index(folder, 1));
+ }
+ // TODO(b/144419024) Atm std.subview is not guaranteed in-bounds. Depending
+ // on the semantics we attach to it, we may need to use min(size, dim) here
+ // and canonicalize later.
+ res.push_back(b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+ }
+
+  // Traverse the lower bounds and subview sizes and erase the values that no
+  // longer have uses. This is a special type of folding that we only apply
+  // when `folder` is defined.
+ if (folder)
+ for (auto v : llvm::concat<Value>(lbs, subViewSizes))
+ if (v->use_empty())
+ v->getDefiningOp()->erase();
+
+ return res;
+}
+
+Optional<TiledLinalgOp>
+mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+ ArrayRef<unsigned> permutation,
+ OperationFolder *folder) {
+  // 1. Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle than
+  // adjusting affine maps to account for missing dimensions.
+ assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
+ op.getNumWindowLoops() ==
+ tileSizes.size() &&
+ "expected matching number of tile sizes and loops");
+
+  // If `permutation` is empty, use the identity map. Otherwise, build the
+  // inverse of the requested permutation map.
+ auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
+ tileSizes.size(), ScopedContext::getContext());
+ if (!permutation.empty())
+ invPermutationMap = inversePermutation(
+ AffineMap::getPermutationMap(permutation, ScopedContext::getContext()));
+
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ ScopedContext scope(b, op.getLoc());
+ // 2. Build the tiled loop ranges.
+ auto viewSizes = getViewSizes(op);
+ // The flattened loopToOperandRangesMaps is expected to be an invertible
+ // permutation map (asserted in the inverse calculation).
+ auto viewSizesToLoopsMap =
+ inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
+ assert(viewSizesToLoopsMap && "expected invertible map");
+
+ SmallVector<SubViewOp::Range, 4> loopRanges;
+ LoopIndexToRangeIndexMap loopIndexToRangeIndex;
+ std::tie(loopRanges, loopIndexToRangeIndex) =
+ makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
+ viewSizes, tileSizes, folder);
+ if (!permutation.empty())
+ applyPermutationToVector(loopRanges, permutation);
+
+ // 3. Create the tiled loops.
+ LinalgOp res = op;
+ SmallVector<IndexHandle, 4> ivs(loopRanges.size());
+ auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+ LoopNestRangeBuilder(pivs, loopRanges)([&] {
+ auto b = ScopedContext::getBuilder();
+ auto loc = ScopedContext::getLocation();
+ SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
+
+    // If we have to apply a permutation to the tiled loop nest, we have to
+    // reorder the induction variables accordingly. Assuming `loopRanges` have
+    // previously been permuted by (i,j,k)->(k,i,j), the permutation to apply
+    // here is the inverse of that one: (d0,d1,d2)->(d2,d0,d1).
+ if (!permutation.empty())
+ ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
+
+ auto views =
+ makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
+ auto operands = getAssumedNonViewOperands(op);
+ views.append(operands.begin(), operands.end());
+ res = op.clone(b, loc, views);
+ });
+
+  // 4. Transform the index arguments of `linalg.indexed_generic` w.r.t. the
+  // tiling.
+ transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
+
+ // 5. Gather the newly created loops and return them with the new op.
+ SmallVector<ForOp, 8> loops;
+ loops.reserve(ivs.size());
+ for (auto iv : ivs)
+ loops.push_back(loop::getForInductionVarOwner(iv));
+
+ return TiledLinalgOp{res, loops};
+}
+
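+// Variant of the tiling entry point that takes static tile sizes and
+// materializes the corresponding index constants. For example, given an
+// OpBuilder `b`, an OperationFolder `folder` and a hypothetical `matmulOp`
+// with iterator order (i, j, k), the following tiles the two outermost loops
+// by 32 and leaves the reduction loop untiled:
+//   tileLinalgOp(b, matmulOp, /*tileSizes=*/{32, 32, 0}, /*permutation=*/{},
+//                &folder);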
+Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
+ OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> permutation, OperationFolder *folder) {
+ if (tileSizes.empty())
+ return llvm::None;
+
+  // The following uses the convention that "tiling by zero" skips tiling a
+  // particular dimension. This convention is significantly simpler to handle
+  // than adjusting affine maps to account for missing dimensions.
+ auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
+ op.getNumWindowLoops();
+ tileSizes = tileSizes.take_front(nLoops);
+  // If all remaining tile sizes are zero, there is nothing to tile.
+ if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
+ return llvm::None;
+
+ // Create a builder for tile size constants.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ ScopedContext scope(b, op.getLoc());
+
+  // Materialize concrete tile size values to pass to the generic tiling
+  // function.
+ SmallVector<Value, 8> tileSizeValues;
+ tileSizeValues.reserve(tileSizes.size());
+ for (auto ts : tileSizes)
+ tileSizeValues.push_back(constant_index(folder, ts));
+ // Pad tile sizes with zero values to enforce our convention.
+ if (tileSizeValues.size() < nLoops) {
+ for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
+ tileSizeValues.push_back(constant_index(folder, 0));
+ }
+
+ return tileLinalgOp(b, op, tileSizeValues, permutation, folder);
+}
+
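+// Tiles every LinalgOp in `f` by `tileSizes` and erases the ops that were
+// successfully tiled. A second walk erases any remaining side-effect free
+// LinalgOp whose results have no uses.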
+static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
+ OpBuilder b(f);
+ OperationFolder folder(f.getContext());
+ f.walk([tileSizes, &b, &folder](LinalgOp op) {
+ auto opLoopsPair =
+ tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder);
+    // If tiling occurred successfully, erase the old op.
+ if (opLoopsPair)
+ op.erase();
+ });
+ f.walk([](LinalgOp op) {
+ if (!op.getOperation()->hasNoSideEffect())
+ return;
+ if (op.getOperation()->use_empty())
+ op.erase();
+ });
+}
+
+namespace {
+struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
+ LinalgTilingPass() = default;
+ LinalgTilingPass(ArrayRef<int64_t> sizes);
+
+ void runOnFunction() override { tileLinalgOps(getFunction(), tileSizes); }
+
+ SmallVector<int64_t, 8> tileSizes;
+};
+} // namespace
+
+LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes) {
+ this->tileSizes.assign(sizes.begin(), sizes.end());
+}
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
+ return std::make_unique<LinalgTilingPass>(tileSizes);
+}
+
+static PassRegistration<LinalgTilingPass>
+ pass("linalg-tile", "Tile operations in the linalg dialect", [] {
+ auto pass = std::make_unique<LinalgTilingPass>();
+ pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+ return pass;
+ });