diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp | 77 |
1 files changed, 72 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 118018b9372..aaa7d9dabf6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -26,11 +27,81 @@ #include "mlir/Pass/Pass.h" using namespace mlir; -using mlir::linalg::LinalgOp; +using namespace mlir::linalg; // Marker used as attribute name in generated Linalg rewriting transformations. static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__"; +static LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, + Operation *op, + ArrayRef<int64_t> sizes, + StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + tileRes->op.getParentOfType<FuncOp>().dump(); + return success(); +} + +static LogicalResult tileAndFuseLinalgOpAndSetMarker(PatternRewriter &rewriter, + Operation *op, + ArrayRef<int64_t> sizes, + StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, op->getParentOfType<FuncOp>()); + auto fusionRes = fuseProducerOf(rewriter, tileRes->op, 0, G); + if (!fusionRes) { + // Linalg fusion requires tiled loops to even determine whether it is + // possible to fuse. As a consequence, the pattern may fail even though a + // tiled version of op has already been introduced. + // So we need to remove the tiled version ourselves in case of failure. + // Another possibility is to ensure the constraints on the pattern guarantee + // that fusion will occur and just assert here. + // As we develop more complex patterns we can choose what is best. + rewriter.eraseOp(tileRes->loops[0]); + return failure(); + } + fusionRes->fusedProducer.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + // The originalProducer can now be safely erased. This is similar to SSA-value + // use-def but in the world of buffer + structured ops. + rewriter.eraseOp(fusionRes->originalProducer); + fusionRes->fusedProducer.getParentOfType<FuncOp>().dump(); + return success(); +} + +template <typename OpTy> +bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { + LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp); + if (!consumer) + return false; + + auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); + if (!maybeConsumerIndex) + return false; + + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, consumer.getParentOfType<FuncOp>()); + for (auto dependence : G.getDependencesInto( + consumer, LinalgDependenceGraph::DependenceType::RAW)) { + auto producer = cast<LinalgOp>(dependence.dependentOpView.op); + if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) + continue; + if (isa<OpTy>(dependence.dependentOpView.op)) + return true; + } + return false; +} + namespace mlir { namespace linalg { namespace { @@ -58,10 +129,6 @@ void LinalgTransforms::runOnFunction() { funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); }); } -std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgTransformsPass() { - return std::make_unique<LinalgTransforms>(); -} - static PassRegistration<LinalgTransforms> pass("test-linalg-transform-patterns", "Test Linalg transformation patterns by applying them greedily."); |