diff options
| author | MLIR Team <no-reply@google.com> | 2019-02-15 17:12:19 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:30:11 -0700 |
| commit | 58aa383e6092cd56dc3b5586b82b53674912dfad (patch) | |
| tree | 0e78684a051b86e86408fd99264e4843e63d7f59 /mlir/lib | |
| parent | ecd403c0e80d049ba8b5ccbf30eb80dd465e8d8b (diff) | |
| download | bcm5719-llvm-58aa383e6092cd56dc3b5586b82b53674912dfad.tar.gz bcm5719-llvm-58aa383e6092cd56dc3b5586b82b53674912dfad.zip | |
Support fusing producer loop nests that write to a live-out memref, provided that the consumer loop nest's write region to the same memref is a superset of the producer's write region.
PiperOrigin-RevId: 234240958
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 117 |
1 files changed, 96 insertions, 21 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cf0f07345a4..aebf2716c4e 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -177,6 +177,15 @@ public: } return storeOpCount; } + + // Returns all store ups in 'storeOps' which access 'memref'. + void getStoreOpsForMemref(Value *memref, + SmallVectorImpl<Instruction *> *storeOps) { + for (auto *storeOpInst : stores) { + if (memref == storeOpInst->cast<StoreOp>()->getMemRef()) + storeOps->push_back(storeOpInst); + } + } }; // Edge represents a data dependece between nodes in the graph. @@ -258,10 +267,10 @@ public: for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast<StoreOp>()->getMemRef(); auto *inst = memref->getDefiningInst(); - // Return false if 'memref' is a block argument. + // Return true if 'memref' is a block argument. if (!inst) return true; - // Return false if any use of 'memref' escapes the function. + // Return true if any use of 'memref' escapes the function. for (auto &use : memref->getUses()) if (!isMemRefDereferencingOp(*use.getOwner())) return true; @@ -1157,6 +1166,63 @@ static uint64_t getSliceIterationCount( return iterCount; } +// Checks if node 'srcId' (which writes to a live out memref), can be safely +// fused into node 'dstId'. Returns true if the following conditions are met: +// *) 'srcNode' writes only writes to live out 'memref'. +// *) 'srcNode' has exaclty one output edge on 'memref' (which is to 'dstId'). +// *) 'dstNode' does write to 'memref'. +// *) 'dstNode's write region to 'memref' is a super set of 'srcNode's write +// region to 'memref'. +// TODO(andydavis) Generalize this to handle more live in/out cases. 
+static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, + Value *memref, + MemRefDependenceGraph *mdg) { + auto *srcNode = mdg->getNode(srcId); + auto *dstNode = mdg->getNode(dstId); + + // Return false if any of the following are true: + // *) 'srcNode' writes to a live in/out memref other than 'memref'. + // *) 'srcNode' has more than one output edge on 'memref'. + // *) 'dstNode' does not write to 'memref'. + if (srcNode->getStoreOpCount(memref) != 1 || + mdg->getOutEdgeCount(srcNode->id, memref) != 1 || + dstNode->getStoreOpCount(memref) == 0) + return false; + // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. + auto *srcStoreOpInst = srcNode->stores.front(); + MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); + srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0); + SmallVector<int64_t, 4> srcShape; + // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. + // by 'srcStoreOpInst' at depth 'dstLoopDepth'. + Optional<int64_t> srcNumElements = + srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); + if (!srcNumElements.hasValue()) + return false; + + // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'. + SmallVector<Instruction *, 2> dstStoreOps; + dstNode->getStoreOpsForMemref(memref, &dstStoreOps); + assert(dstStoreOps.size() == 1); + auto *dstStoreOpInst = dstStoreOps[0]; + MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc()); + dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0); + SmallVector<int64_t, 4> dstShape; + // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'. + // by 'dstStoreOpInst' at depth 'dstLoopDepth'. + Optional<int64_t> dstNumElements = + dstWriteRegion.getConstantBoundingSizeAndShape(&dstShape); + if (!dstNumElements.hasValue()) + return false; + + // Return false if write region is not a superset of 'srcNodes' write + // region to 'memref'. + // TODO(andydavis) Check the shape and lower bounds here too. 
+ if (srcNumElements != dstNumElements) + return false; + return true; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // Returns true if it is profitable to fuse the candidate loop nests. Returns @@ -1593,8 +1659,12 @@ public: if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0) continue; - // Skip if 'srcNode' writes to any live in or escaping memrefs. - if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) + // Skip if 'srcNode' writes to any live in or escaping memrefs, + // and cannot be fused. + bool writesToLiveInOrOut = + mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); + if (writesToLiveInOrOut && + !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg)) continue; // Compute an instruction list insertion point for the fused loop @@ -1639,22 +1709,24 @@ public: for (auto forOp : sliceCollector.forOps) { promoteIfSingleIteration(forOp); } - // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector<Instruction *, 4> storesForMemref; - for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (storeOpInst->cast<StoreOp>()->getMemRef() == memref) - storesForMemref.push_back(storeOpInst); + if (!writesToLiveInOrOut) { + // Create private memref for 'memref' in 'dstAffineForOp'. + SmallVector<Instruction *, 4> storesForMemref; + for (auto *storeOpInst : sliceCollector.storeOpInsts) { + if (storeOpInst->cast<StoreOp>()->getMemRef() == memref) + storesForMemref.push_back(storeOpInst); + } + assert(storesForMemref.size() == 1); + auto *newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = + mdg->addNode(newMemRef->getDefiningInst()); + // Add edge from 'newMemRef' node to dstNode. 
+ mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } - assert(storesForMemref.size() == 1); - auto *newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = - mdg->addNode(newMemRef->getDefiningInst()); - // Add edge from 'newMemRef' node to dstNode. - mdg->addEdge(newMemRefNodeId, dstId, newMemRef); // Collect dst loop stats after memref privatizaton transformation. LoopNestStateCollector dstLoopCollector; @@ -1674,8 +1746,11 @@ public: dstLoopCollector.storeOpInsts); // Remove old src loop nest if it no longer has outgoing dependence // edges, and it does not write to a memref which escapes the - // function. - if (mdg->canRemoveNode(srcNode->id)) { + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has + // been fused into 'dstNode' and write region of 'dstNode' covers + // the write region of 'srcNode', and 'srcNode' has no other users + // so it is safe to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); srcNode->inst->erase(); } else { |

