summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp77
1 files changed, 72 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 118018b9372..aaa7d9dabf6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -26,11 +27,81 @@
#include "mlir/Pass/Pass.h"
using namespace mlir;
-using mlir::linalg::LinalgOp;
+using namespace mlir::linalg;
// Marker used as attribute name in generated Linalg rewriting transformations.
static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__";
+static LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter,
+ Operation *op,
+ ArrayRef<int64_t> sizes,
+ StringRef linalgMarker) {
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+ if (!tileRes)
+ return failure();
+ tileRes->op.setAttr(kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ tileRes->op.getParentOfType<FuncOp>().dump();
+ return success();
+}
+
+static LogicalResult tileAndFuseLinalgOpAndSetMarker(PatternRewriter &rewriter,
+ Operation *op,
+ ArrayRef<int64_t> sizes,
+ StringRef linalgMarker) {
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+ if (!tileRes)
+ return failure();
+ tileRes->op.setAttr(kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ Aliases aliases;
+ auto G = LinalgDependenceGraph::buildDependenceGraph(
+ aliases, op->getParentOfType<FuncOp>());
+ auto fusionRes = fuseProducerOf(rewriter, tileRes->op, 0, G);
+ if (!fusionRes) {
+ // Linalg fusion requires tiled loops to even determine whether it is
+ // possible to fuse. As a consequence, the pattern may fail even though a
+ // tiled version of op has already been introduced.
+ // So we need to remove the tiled version ourselves in case of failure.
+ // Another possibility is to ensure the constraints on the pattern guarantee
+ // that fusion will occur and just assert here.
+ // As we develop more complex patterns we can choose what is best.
+ rewriter.eraseOp(tileRes->loops[0]);
+ return failure();
+ }
+ fusionRes->fusedProducer.setAttr(kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
+ // The originalProducer can now be safely erased. This is similar to SSA-value
+ // use-def but in the world of buffer + structured ops.
+ rewriter.eraseOp(fusionRes->originalProducer);
+ fusionRes->fusedProducer.getParentOfType<FuncOp>().dump();
+ return success();
+}
+
+template <typename OpTy>
+bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
+ LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
+ if (!consumer)
+ return false;
+
+ auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView);
+ if (!maybeConsumerIndex)
+ return false;
+
+ Aliases aliases;
+ auto G = LinalgDependenceGraph::buildDependenceGraph(
+ aliases, consumer.getParentOfType<FuncOp>());
+ for (auto dependence : G.getDependencesInto(
+ consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+ auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+ if (!isProducerLastWriteOfView(G, consumer, consumedView, producer))
+ continue;
+ if (isa<OpTy>(dependence.dependentOpView.op))
+ return true;
+ }
+ return false;
+}
+
namespace mlir {
namespace linalg {
namespace {
@@ -58,10 +129,6 @@ void LinalgTransforms::runOnFunction() {
funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); });
}
-std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgTransformsPass() {
- return std::make_unique<LinalgTransforms>();
-}
-
static PassRegistration<LinalgTransforms>
pass("test-linalg-transform-patterns",
"Test Linalg transformation patterns by applying them greedily.");
OpenPOWER on IntegriCloud