diff options
author | Nicolas Vasilache <ntv@google.com> | 2019-11-01 08:29:42 -0700 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-01 08:30:38 -0700 |
commit | bd94a10c02a641e59c5ccfec143f728e13b516c2 (patch) | |
tree | d32e22e8224f1fd5a90d804f7ec845917dcb68e8 /mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | |
parent | 96531e2f871d74f6bc224446c40b37425d58a5b4 (diff) | |
download | bcm5719-llvm-bd94a10c02a641e59c5ccfec143f728e13b516c2.tar.gz bcm5719-llvm-bd94a10c02a641e59c5ccfec143f728e13b516c2.zip |
Add Linalg pattern for producer-consumer fusion
This CL adds a simple pattern for specifying producer-consumer fusion on Linalg operations.
Implementing such an extension reveals some interesting properties.
Since Linalg operates on a buffer abstraction, the output buffers are specified as in/out parameters to the ops. As a consequence, there are no SSA use-def chains and one cannot specify complex dag input patterns with the current infrastructure.
Instead this CL uses constraints based on the existing linalg dependence analysis to focus the pattern and refine patterns based on the type of op that last wrote in a buffer.
This is a very local property and is less powerful than the generic dag specification based on SSA use-def chains.
This will be generalized in the future.
PiperOrigin-RevId: 277931503
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 64 |
1 files changed, 47 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 8e7370af7a3..82699545b3f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -201,19 +201,13 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. -static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, +static bool isStructurallyFusableProducer(LinalgOp producer, + Value *consumedView, LinalgOp consumer) { if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; } - // Must be a subview or a slice to guarantee there are loops we can fuse into. - auto subView = dyn_cast_or_null<SubViewOp>(readView->getDefiningOp()); - auto slice = dyn_cast_or_null<SliceOp>(readView->getDefiningOp()); - if (!subView && !slice) { - LLVM_DEBUG(dbgs() << "\nNot structurally fusable (not a subview or slice)"); - return false; - } // Only fuse when the producer block dominates. DominanceInfo dom(producer.getOperation()); if (!dom.dominates(producer.getOperation()->getBlock(), @@ -226,6 +220,41 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } +bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, + Value *consumedView, + LinalgOp producer) { + // Make some simple structural checks that alleviate the need for more + // complex analyses. + if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { + LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" + << *producer.getOperation()); + return false; + } + // Check for any interleaved write to consumedView. + if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" + << *producer.getOperation()); + return false; + } + return true; +} + +bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value *consumedView, + LinalgOp producer) { + if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) + return false; + // Check for any fusion-preventing dependence to any view read/written that + // would violate dependences. + if (!graph.findCoveringDependences(producer, consumer).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" + << *producer.getOperation()); + return false; + } + return true; +} + // Only consider RAW atm. Optional<FusionInfo> mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, @@ -239,8 +268,8 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( auto producer = cast<LinalgOp>(dependence.dependentOpView.op); // Check that the dependence is indeed on the input `consumerIdx` view. - auto *readView = dependence.indexingView; - if (consumer.getInput(consumerIdx) != readView) + auto *consumedView = dependence.indexingView; + if (consumer.getInput(consumerIdx) != consumedView) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks @@ -252,16 +281,17 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf( << " view: " << *producedView << " output index: " << producerIdx); - // Make some simple structural checks that alleviate the need for more - // complex analyses. - if (!isStructurallyFusableProducer(producer, readView, consumer)) { - LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); + // Must be a subview or a slice to guarantee there are loops we can fuse + // into. + auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp()); + auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp()); + if (!subView && !slice) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); continue; } - // Check for fusion-preventing write that would violate dependences. - // `view` is a producer write that cannot bypass any other write or read. - if (!graph.findCoveringDependences(producer, consumer).empty()) + // Simple fusability checks. + if (!isFusableInto(graph, consumer, consumedView, producer)) continue; // Fuse `producer` just before `consumer`. |