summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp117
1 files changed, 96 insertions, 21 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index cf0f07345a4..aebf2716c4e 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -177,6 +177,15 @@ public:
}
return storeOpCount;
}
+
+ // Returns all store ups in 'storeOps' which access 'memref'.
+ void getStoreOpsForMemref(Value *memref,
+ SmallVectorImpl<Instruction *> *storeOps) {
+ for (auto *storeOpInst : stores) {
+ if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
+ storeOps->push_back(storeOpInst);
+ }
+ }
};
// Edge represents a data dependece between nodes in the graph.
@@ -258,10 +267,10 @@ public:
for (auto *storeOpInst : node->stores) {
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
auto *inst = memref->getDefiningInst();
- // Return false if 'memref' is a block argument.
+ // Return true if 'memref' is a block argument.
if (!inst)
return true;
- // Return false if any use of 'memref' escapes the function.
+ // Return true if any use of 'memref' escapes the function.
for (auto &use : memref->getUses())
if (!isMemRefDereferencingOp(*use.getOwner()))
return true;
@@ -1157,6 +1166,63 @@ static uint64_t getSliceIterationCount(
return iterCount;
}
+// Checks if node 'srcId' (which writes to a live out memref), can be safely
+// fused into node 'dstId'. Returns true if the following conditions are met:
+// *) 'srcNode' writes only writes to live out 'memref'.
+// *) 'srcNode' has exaclty one output edge on 'memref' (which is to 'dstId').
+// *) 'dstNode' does write to 'memref'.
+// *) 'dstNode's write region to 'memref' is a super set of 'srcNode's write
+// region to 'memref'.
+// TODO(andydavis) Generalize this to handle more live in/out cases.
+static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+ Value *memref,
+ MemRefDependenceGraph *mdg) {
+ auto *srcNode = mdg->getNode(srcId);
+ auto *dstNode = mdg->getNode(dstId);
+
+ // Return false if any of the following are true:
+ // *) 'srcNode' writes to a live in/out memref other than 'memref'.
+ // *) 'srcNode' has more than one output edge on 'memref'.
+ // *) 'dstNode' does not write to 'memref'.
+ if (srcNode->getStoreOpCount(memref) != 1 ||
+ mdg->getOutEdgeCount(srcNode->id, memref) != 1 ||
+ dstNode->getStoreOpCount(memref) == 0)
+ return false;
+ // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
+ auto *srcStoreOpInst = srcNode->stores.front();
+ MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
+ srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
+ SmallVector<int64_t, 4> srcShape;
+ // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
+ // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+ Optional<int64_t> srcNumElements =
+ srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
+ if (!srcNumElements.hasValue())
+ return false;
+
+ // Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'.
+ SmallVector<Instruction *, 2> dstStoreOps;
+ dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
+ assert(dstStoreOps.size() == 1);
+ auto *dstStoreOpInst = dstStoreOps[0];
+ MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
+ dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0);
+ SmallVector<int64_t, 4> dstShape;
+ // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'.
+ // by 'dstStoreOpInst' at depth 'dstLoopDepth'.
+ Optional<int64_t> dstNumElements =
+ dstWriteRegion.getConstantBoundingSizeAndShape(&dstShape);
+ if (!dstNumElements.hasValue())
+ return false;
+
+ // Return false if write region is not a superset of 'srcNodes' write
+ // region to 'memref'.
+ // TODO(andydavis) Check the shape and lower bounds here too.
+ if (srcNumElements != dstNumElements)
+ return false;
+ return true;
+}
+
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// Returns true if it is profitable to fuse the candidate loop nests. Returns
@@ -1593,8 +1659,12 @@ public:
if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) != 0)
continue;
- // Skip if 'srcNode' writes to any live in or escaping memrefs.
- if (mdg->writesToLiveInOrEscapingMemrefs(srcNode->id))
+ // Skip if 'srcNode' writes to any live in or escaping memrefs,
+ // and cannot be fused.
+ bool writesToLiveInOrOut =
+ mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
+ if (writesToLiveInOrOut &&
+ !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
continue;
// Compute an instruction list insertion point for the fused loop
@@ -1639,22 +1709,24 @@ public:
for (auto forOp : sliceCollector.forOps) {
promoteIfSingleIteration(forOp);
}
- // Create private memref for 'memref' in 'dstAffineForOp'.
- SmallVector<Instruction *, 4> storesForMemref;
- for (auto *storeOpInst : sliceCollector.storeOpInsts) {
- if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
- storesForMemref.push_back(storeOpInst);
+ if (!writesToLiveInOrOut) {
+ // Create private memref for 'memref' in 'dstAffineForOp'.
+ SmallVector<Instruction *, 4> storesForMemref;
+ for (auto *storeOpInst : sliceCollector.storeOpInsts) {
+ if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
+ storesForMemref.push_back(storeOpInst);
+ }
+ assert(storesForMemref.size() == 1);
+ auto *newMemRef = createPrivateMemRef(
+ dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+ fastMemorySpace, localBufSizeThreshold);
+ visitedMemrefs.insert(newMemRef);
+ // Create new node in dependence graph for 'newMemRef' alloc op.
+ unsigned newMemRefNodeId =
+ mdg->addNode(newMemRef->getDefiningInst());
+ // Add edge from 'newMemRef' node to dstNode.
+ mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
}
- assert(storesForMemref.size() == 1);
- auto *newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
- visitedMemrefs.insert(newMemRef);
- // Create new node in dependence graph for 'newMemRef' alloc op.
- unsigned newMemRefNodeId =
- mdg->addNode(newMemRef->getDefiningInst());
- // Add edge from 'newMemRef' node to dstNode.
- mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
// Collect dst loop stats after memref privatizaton transformation.
LoopNestStateCollector dstLoopCollector;
@@ -1674,8 +1746,11 @@ public:
dstLoopCollector.storeOpInsts);
// Remove old src loop nest if it no longer has outgoing dependence
// edges, and it does not write to a memref which escapes the
- // function.
- if (mdg->canRemoveNode(srcNode->id)) {
+ // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
+ // been fused into 'dstNode' and write region of 'dstNode' covers
+ // the write region of 'srcNode', and 'srcNode' has no other users
+ // so it is safe to remove.
+ if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
mdg->removeNode(srcNode->id);
srcNode->inst->erase();
} else {
OpenPOWER on IntegriCloud