summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp346
1 files changed, 346 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
new file mode 100644
index 00000000000..9df7bce0879
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -0,0 +1,346 @@
+//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the linalg dialect Fusion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-fusion"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+using namespace mlir::linalg::intrinsics;
+
+using llvm::dbgs;
+
+/// Implements a simple high-level fusion pass of linalg library operations.
+///
+/// In each block, linalg ops are processed in reverse textual order.
+/// Given a linalg op `O`, fusion occurs by:
+/// 1. inspecting the linalg ops that write into the views read by `O`. This
+/// uses the SSA value of the views and a simple subview/slice analysis to
+/// determine producer-consumer dependences;
+/// 2. greedily fuse the linalg ops that produce subview
+/// 3. inspect the fused ops and determine whether they have other remaining
+/// LinalgOp uses. If not, then erase the original producing linalg op.
+///
+/// More advanced use cases, analyses as well as profitability heuristics are
+/// left for future work.
+
+static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
+static llvm::cl::list<unsigned> clTileSizes(
+ "linalg-fusion-tile-sizes",
+ llvm::cl::desc(
+ "Tile sizes by which to tile linalg operations during linalg fusion"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::cat(clOptionsCategory));
+
+// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
+// a subset of the original loop ranges of `op`.
+// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
+// to the `loopRanges` in order to obtain view ranges.
+static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
+ ArrayRef<SubViewOp::Range> loopRanges) {
+ auto maps = loopToOperandRangesMaps(op);
+ SmallVector<Value, 8> clonedViews;
+ clonedViews.reserve(op.getNumInputsAndOutputs());
+ // Iterate over the inputs and outputs in order.
+ // Extract the subranges from the linearized ranges.
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ for (auto en : llvm::enumerate(ios)) {
+ unsigned idx = en.index();
+ auto map = maps[idx];
+ LLVM_DEBUG(dbgs() << "map: " << map << "\n");
+ Value view = en.value();
+ SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
+ for (auto en2 : llvm::enumerate(map.getResults())) {
+ unsigned d = en2.index();
+ // loopToOperandRangesMaps are permutations-only.
+ unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
+ viewRanges[d] = loopRanges[loopPos];
+ LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
+ << "\t"
+ << "loopPos: " << loopPos << "\t" << viewRanges[d]);
+ }
+ // Construct a new subview for the tile.
+ unsigned rank = viewRanges.size();
+ SmallVector<Value, 4> offsets, sizes, strides;
+ offsets.reserve(rank);
+ sizes.reserve(rank);
+ strides.reserve(rank);
+ for (auto r : viewRanges) {
+ offsets.push_back(r.offset);
+ sizes.push_back(r.size);
+ strides.push_back(r.stride);
+ }
+ clonedViews.push_back(
+ b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+ }
+ auto operands = getAssumedNonViewOperands(op);
+ clonedViews.append(operands.begin(), operands.end());
+ return op.clone(b, loc, clonedViews);
+}
+
+struct ViewDimension {
+ Value view;
+ unsigned dimension;
+};
+
+// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
+// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
+// guarantees at least one such dimension is found. If multiple candidates exist
+// they must agree by construction (i.e. have the same size) and we just return
+// the first one.
+static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
+ auto maps = loopToOperandRangesMaps(op);
+ // Iterate over the inputs and outputs in order.
+ // Extract the subranges from the linearized ranges.
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs());
+ for (auto en : llvm::enumerate(ios)) {
+ unsigned idx = en.index();
+ auto map = maps[idx];
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
+ Value view = en.value();
+ SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
+ for (auto en2 : llvm::enumerate(map.getResults())) {
+ if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+ << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
+ << "\n");
+ return ViewDimension{view, static_cast<unsigned>(en2.index())};
+ }
+ }
+ }
+ llvm_unreachable("Expect to be able to extract a view defining loop range");
+}
+
+static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
+ unsigned consumerIdx, unsigned producerIdx,
+ OperationFolder *folder) {
+ auto subView = dyn_cast_or_null<SubViewOp>(
+ consumer.getInput(consumerIdx)->getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(
+ consumer.getInput(consumerIdx)->getDefiningOp());
+ assert(subView || slice);
+ (void)subView;
+ (void)slice;
+
+ // loopToOperandRangesMaps are permutations-only by construction:
+ // we can always identify a data dimension with a (at least one) loop
+ // dimension.
+ AffineMap producerMap =
+ loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
+ LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
+ << ", producer map: " << producerMap << "\n");
+
+ unsigned nPar = producer.getNumParallelLoops();
+ unsigned nRed = producer.getNumReductionLoops();
+ unsigned nWin = producer.getNumWindowLoops();
+ SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
+
+ // Iterate over dimensions identified by the producer map for `producerIdx`.
+ // This defines a subset of the loop ranges that we need to complete later.
+ for (auto en : llvm::enumerate(producerMap.getResults())) {
+ unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+ loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
+ }
+
+ OpBuilder b(consumer.getOperation());
+ auto loc = consumer.getLoc();
+ // Iterate over all dimensions. For the dimensions not identified by the
+ // producer map for `producerIdx`, we need to explicitly compute the view that
+ // defines the loop ranges using the `producer`.
+ for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
+ if (loopRanges[i].offset)
+ LLVM_DEBUG(llvm::dbgs()
+ << "existing LoopRange: " << loopRanges[i] << "\n");
+ else {
+ auto viewDim = getViewDefiningLoopRange(producer, i);
+ loopRanges[i] = SubViewOp::Range{constant_index(folder, 0),
+ dim(viewDim.view, viewDim.dimension),
+ constant_index(folder, 1)};
+ LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
+ }
+ }
+
+ return cloneWithLoopRanges(b, loc, producer, loopRanges);
+}
+
+// Encode structural fusion safety preconditions.
+// Some of these will be lifted in the future with better analysis.
+static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
+ LinalgOp consumer) {
+ if (producer.getNumOutputs() != 1) {
+ LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
+ return false;
+ }
+ // Only fuse when the producer block dominates.
+ DominanceInfo dom(producer.getOperation());
+ if (!dom.dominates(producer.getOperation()->getBlock(),
+ consumer.getOperation()->getBlock())) {
+ LLVM_DEBUG(
+ dbgs()
+ << "\nNot structurally fusable (producer block does not dominate)");
+ return false;
+ }
+ return true;
+}
+
+bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
+ LinalgOp consumer,
+ Value consumedView,
+ LinalgOp producer) {
+ // Make some simple structural checks that alleviate the need for more
+ // complex analyses.
+ if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
+ LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ // Check for any interleaved write to consumedView.
+ if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
+ LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ return true;
+}
+
+bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
+ LinalgOp consumer, Value consumedView,
+ LinalgOp producer) {
+ if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
+ return false;
+ // Check for any fusion-preventing dependence to any view read/written that
+ // would violate dependences.
+ if (!graph.findCoveringDependences(producer, consumer).empty()) {
+ LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ return true;
+}
+
+// Only consider RAW atm.
+Optional<FusionInfo> mlir::linalg::fuseProducerOf(
+ OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
+ const LinalgDependenceGraph &graph, OperationFolder *folder) {
+ LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
+ << *consumer.getOperation());
+ for (auto dependence : graph.getDependencesInto(
+ consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+ LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
+ << *dependence.dependentOpView.op << "\n");
+ auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+
+ // Check that the dependence is indeed on the input `consumerIdx` view.
+ auto consumedView = dependence.indexingView;
+ if (consumer.getInput(consumerIdx) != consumedView)
+ continue;
+
+ // Consumer consumes this view, `isStructurallyFusableProducer` also checks
+ // whether it is a strict subview of the producer view.
+ auto producedView = dependence.dependentOpView.view;
+ auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
+ // `consumerIdx` and `producerIdx` exist by construction.
+ LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
+ << " view: " << *producedView
+ << " output index: " << producerIdx);
+
+ // Must be a subview or a slice to guarantee there are loops we can fuse
+ // into.
+ auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp());
+ if (!subView && !slice) {
+ LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
+ continue;
+ }
+
+ // Simple fusability checks.
+ if (!isFusableInto(graph, consumer, consumedView, producer))
+ continue;
+
+ // Fuse `producer` just before `consumer`.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(consumer.getOperation());
+ ScopedContext scope(b, consumer.getLoc());
+ LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+ auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
+ producerIdx, folder);
+
+ return FusionInfo{producer, fusedProducer};
+ }
+ return llvm::None;
+}
+
+static void fuseLinalgOpsGreedily(FuncOp f) {
+ LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
+
+ OpBuilder b(f);
+ OperationFolder folder(f.getContext());
+ DenseSet<Operation *> eraseSet;
+
+ // Save original Linalg ops, we only want to make a pass over those.
+ SmallVector<Operation *, 8> linalgOps;
+ f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
+
+ Aliases aliases;
+ LinalgDependenceGraph G(aliases, linalgOps);
+ for (auto *op : llvm::reverse(linalgOps)) {
+ for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
+ consumerIdx < e; ++consumerIdx) {
+ if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder))
+ eraseSet.insert(fusionInfo->originalProducer.getOperation());
+ }
+ }
+
+ // The `fuseProducerOf` function performs structural checks and in particular
+ // that no covering read or write exist between the consumer and the producer.
+ // As a consequence, the only fusions that may occur preserve subsequent
+ // dependences and are guaranteed by construction to produce the whole view.
+ // We may thus erase the producer once it is fused.
+ for (auto *e : eraseSet)
+ e->erase();
+ LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
+}
+
+namespace {
+struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
+ void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
+};
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgFusionPass() {
+ return std::make_unique<LinalgFusionPass>();
+}
+
+static PassRegistration<LinalgFusionPass>
+ pass("linalg-fusion", "Fuse operations in the linalg dialect");
OpenPOWER on IntegriCloud