summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-11-01 08:29:42 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-01 08:30:38 -0700
commitbd94a10c02a641e59c5ccfec143f728e13b516c2 (patch)
treed32e22e8224f1fd5a90d804f7ec845917dcb68e8 /mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
parent96531e2f871d74f6bc224446c40b37425d58a5b4 (diff)
downloadbcm5719-llvm-bd94a10c02a641e59c5ccfec143f728e13b516c2.tar.gz
bcm5719-llvm-bd94a10c02a641e59c5ccfec143f728e13b516c2.zip
Add Linalg pattern for producer-consumer fusion
This CL adds a simple pattern for specifying producer-consumer fusion on Linalg operations. Implementing such an extension reveals some interesting properties. Since Linalg operates on a buffer abstraction, the output buffers are specified as in/out parameters to the ops. As a consequence, there are no SSA use-def chains and one cannot specify complex dag input patterns with the current infrastructure. Instead this CL uses constraints based on the existing linalg dependence analysis to focus the pattern and refine patterns based on the type of op that last wrote in a buffer. This is a very local property and is less powerful than the generic dag specification based on SSA use-def chains. This will be generalized in the future. PiperOrigin-RevId: 277931503
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp77
1 files changed, 72 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 118018b9372..aaa7d9dabf6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -26,11 +27,81 @@
#include "mlir/Pass/Pass.h"
using namespace mlir;
-using mlir::linalg::LinalgOp;
+using namespace mlir::linalg;
// Marker used as attribute name in generated Linalg rewriting transformations.
static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__";
+static LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter,
+ Operation *op,
+ ArrayRef<int64_t> sizes,
+ StringRef linalgMarker) {
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+ if (!tileRes)
+ return failure();
+ tileRes->op.setAttr(kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ tileRes->op.getParentOfType<FuncOp>().dump();
+ return success();
+}
+
+static LogicalResult tileAndFuseLinalgOpAndSetMarker(PatternRewriter &rewriter,
+ Operation *op,
+ ArrayRef<int64_t> sizes,
+ StringRef linalgMarker) {
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+ if (!tileRes)
+ return failure();
+ tileRes->op.setAttr(kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ Aliases aliases;
+ auto G = LinalgDependenceGraph::buildDependenceGraph(
+ aliases, op->getParentOfType<FuncOp>());
+ auto fusionRes = fuseProducerOf(rewriter, tileRes->op, 0, G);
+ if (!fusionRes) {
+ // Linalg fusion requires tiled loops to even determine whether it is
+ // possible to fuse. As a consequence, the pattern may fail even though a
+ // tiled version of op has already been introduced.
+ // So we need to remove the tiled version ourselves in case of failure.
+ // Another possibility is to ensure the constraints on the pattern guarantee
+ // that fusion will occur and just assert here.
+ // As we develop more complex patterns we can choose what is best.
+ rewriter.eraseOp(tileRes->loops[0]);
+ return failure();
+ }
+ fusionRes->fusedProducer.setAttr(kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ // The originalProducer can now be safely erased. This is similar to SSA-value
+ // use-def but in the world of buffer + structured ops.
+ rewriter.eraseOp(fusionRes->originalProducer);
+ fusionRes->fusedProducer.getParentOfType<FuncOp>().dump();
+ return success();
+}
+
+template <typename OpTy>
+bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
+ LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
+ if (!consumer)
+ return false;
+
+ auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView);
+ if (!maybeConsumerIndex)
+ return false;
+
+ Aliases aliases;
+ auto G = LinalgDependenceGraph::buildDependenceGraph(
+ aliases, consumer.getParentOfType<FuncOp>());
+ for (auto dependence : G.getDependencesInto(
+ consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+ auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+ if (!isProducerLastWriteOfView(G, consumer, consumedView, producer))
+ continue;
+ if (isa<OpTy>(dependence.dependentOpView.op))
+ return true;
+ }
+ return false;
+}
+
namespace mlir {
namespace linalg {
namespace {
@@ -58,10 +129,6 @@ void LinalgTransforms::runOnFunction() {
funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); });
}
-std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgTransformsPass() {
- return std::make_unique<LinalgTransforms>();
-}
-
static PassRegistration<LinalgTransforms>
pass("test-linalg-transform-patterns",
"Test Linalg transformation patterns by applying them greedily.");
OpenPOWER on IntegriCloud