summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp64
1 files changed, 47 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 8e7370af7a3..82699545b3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -201,19 +201,13 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer,
// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
-static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
+static bool isStructurallyFusableProducer(LinalgOp producer,
+ Value *consumedView,
LinalgOp consumer) {
if (producer.getNumOutputs() != 1) {
LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
return false;
}
- // Must be a subview or a slice to guarantee there are loops we can fuse into.
- auto subView = dyn_cast_or_null<SubViewOp>(readView->getDefiningOp());
- auto slice = dyn_cast_or_null<SliceOp>(readView->getDefiningOp());
- if (!subView && !slice) {
- LLVM_DEBUG(dbgs() << "\nNot structurally fusable (not a subview or slice)");
- return false;
- }
// Only fuse when the producer block dominates.
DominanceInfo dom(producer.getOperation());
if (!dom.dominates(producer.getOperation()->getBlock(),
@@ -226,6 +220,41 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
return true;
}
+bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
+ LinalgOp consumer,
+ Value *consumedView,
+ LinalgOp producer) {
+ // Make some simple structural checks that alleviate the need for more
+ // complex analyses.
+ if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
+ LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ // Check for any interleaved write to consumedView.
+ if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
+ LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ return true;
+}
+
+bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
+ LinalgOp consumer, Value *consumedView,
+ LinalgOp producer) {
+ if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
+ return false;
+ // Check for any fusion-preventing dependence to any view read/written that
+ // would violate dependences.
+ if (!graph.findCoveringDependences(producer, consumer).empty()) {
+ LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
+ << *producer.getOperation());
+ return false;
+ }
+ return true;
+}
+
// Only consider RAW atm.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
@@ -239,8 +268,8 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
// Check that the dependence is indeed on the input `consumerIdx` view.
- auto *readView = dependence.indexingView;
- if (consumer.getInput(consumerIdx) != readView)
+ auto *consumedView = dependence.indexingView;
+ if (consumer.getInput(consumerIdx) != consumedView)
continue;
// Consumer consumes this view, `isStructurallyFusableProducer` also checks
@@ -252,16 +281,17 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
<< " view: " << *producedView
<< " output index: " << producerIdx);
- // Make some simple structural checks that alleviate the need for more
- // complex analyses.
- if (!isStructurallyFusableProducer(producer, readView, consumer)) {
- LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation());
+ // Must be a subview or a slice to guarantee there are loops we can fuse
+ // into.
+ auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp());
+ if (!subView && !slice) {
+ LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
continue;
}
- // Check for fusion-preventing write that would violate dependences.
- // `view` is a producer write that cannot bypass any other write or read.
- if (!graph.findCoveringDependences(producer, consumer).empty())
+ // Simple fusability checks.
+ if (!isFusableInto(graph, consumer, consumedView, producer))
continue;
// Fuse `producer` just before `consumer`.
OpenPOWER on IntegriCloud