diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 64 |
1 files changed, 47 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 8e7370af7a3..82699545b3f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -201,19 +201,13 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. -static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, +static bool isStructurallyFusableProducer(LinalgOp producer, + Value *consumedView, LinalgOp consumer) { if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; } - // Must be a subview or a slice to guarantee there are loops we can fuse into. - auto subView = dyn_cast_or_null<SubViewOp>(readView->getDefiningOp()); - auto slice = dyn_cast_or_null<SliceOp>(readView->getDefiningOp()); - if (!subView && !slice) { - LLVM_DEBUG(dbgs() << "\nNot structurally fusable (not a subview or slice)"); - return false; - } // Only fuse when the producer block dominates. DominanceInfo dom(producer.getOperation()); if (!dom.dominates(producer.getOperation()->getBlock(), @@ -226,6 +220,41 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } +bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, + Value *consumedView, + LinalgOp producer) { + // Make some simple structural checks that alleviate the need for more + // complex analyses. + if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { + LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" + << *producer.getOperation()); + return false; + } + // Check for any interleaved write to consumedView. + if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" + << *producer.getOperation()); + return false; + } + return true; +} + +bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value *consumedView, + LinalgOp producer) { + if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) + return false; + // Check for any fusion-preventing dependence to any view read/written that + // would violate dependences. + if (!graph.findCoveringDependences(producer, consumer).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" + << *producer.getOperation()); + return false; + } + return true; +} + // Only consider RAW atm. Optional<FusionInfo> mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, @@ -239,8 +268,8 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( auto producer = cast<LinalgOp>(dependence.dependentOpView.op); // Check that the dependence is indeed on the input `consumerIdx` view. - auto *readView = dependence.indexingView; - if (consumer.getInput(consumerIdx) != readView) + auto *consumedView = dependence.indexingView; + if (consumer.getInput(consumerIdx) != consumedView) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks @@ -252,16 +281,17 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( << " view: " << *producedView << " output index: " << producerIdx); - // Make some simple structural checks that alleviate the need for more - // complex analyses. - if (!isStructurallyFusableProducer(producer, readView, consumer)) { - LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); + // Must be a subview or a slice to guarantee there are loops we can fuse + // into. + auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp()); + auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp()); + if (!subView && !slice) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); continue; } - // Check for fusion-preventing write that would violate dependences. - // `view` is a producer write that cannot bypass any other write or read. - if (!graph.findCoveringDependences(producer, consumer).empty()) + // Simple fusability checks. + if (!isFusableInto(graph, consumer, consumedView, producer)) continue; // Fuse `producer` just before `consumer`. |