Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 93
1 files changed, 47 insertions, 46 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 5694c990b9b..60f0264eb35 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -172,7 +172,7 @@ public:
     Node(unsigned id, Operation *op) : id(id), op(op) {}

     // Returns the load op count for 'memref'.
-    unsigned getLoadOpCount(Value *memref) {
+    unsigned getLoadOpCount(ValuePtr memref) {
       unsigned loadOpCount = 0;
       for (auto *loadOpInst : loads) {
         if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
@@ -182,7 +182,7 @@ public:
     }

     // Returns the store op count for 'memref'.
-    unsigned getStoreOpCount(Value *memref) {
+    unsigned getStoreOpCount(ValuePtr memref) {
       unsigned storeOpCount = 0;
       for (auto *storeOpInst : stores) {
         if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
@@ -192,7 +192,7 @@ public:
     }

     // Returns all store ops in 'storeOps' which access 'memref'.
-    void getStoreOpsForMemref(Value *memref,
+    void getStoreOpsForMemref(ValuePtr memref,
                               SmallVectorImpl<Operation *> *storeOps) {
       for (auto *storeOpInst : stores) {
         if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
@@ -201,7 +201,7 @@ public:
     }

     // Returns all load ops in 'loadOps' which access 'memref'.
-    void getLoadOpsForMemref(Value *memref,
+    void getLoadOpsForMemref(ValuePtr memref,
                              SmallVectorImpl<Operation *> *loadOps) {
       for (auto *loadOpInst : loads) {
         if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
@@ -211,13 +211,13 @@ public:

     // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
     // has at least one load and store operation.
-    void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
-      llvm::SmallDenseSet<Value *, 2> loadMemrefs;
+    void getLoadAndStoreMemrefSet(DenseSet<ValuePtr> *loadAndStoreMemrefSet) {
+      llvm::SmallDenseSet<ValuePtr, 2> loadMemrefs;
       for (auto *loadOpInst : loads) {
         loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
       }
       for (auto *storeOpInst : stores) {
-        auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+        auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
         if (loadMemrefs.count(memref) > 0)
           loadAndStoreMemrefSet->insert(memref);
       }
@@ -239,7 +239,7 @@ public:
     // defines an SSA value and another graph node which uses the SSA value
     // (e.g. a constant operation defining a value which is used inside a loop
     // nest).
-    Value *value;
+    ValuePtr value;
   };

   // Map from node id to Node.
@@ -250,7 +250,7 @@ public:
   DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
   // Map from memref to a count on the dependence edges associated with that
   // memref.
-  DenseMap<Value *, unsigned> memrefEdgeCount;
+  DenseMap<ValuePtr, unsigned> memrefEdgeCount;

   // The next unique identifier to use for newly created graph nodes.
   unsigned nextNodeId = 0;
@@ -309,7 +309,7 @@ public:
   bool writesToLiveInOrEscapingMemrefs(unsigned id) {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
-      auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+      auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
       auto *op = memref->getDefiningOp();
       // Return true if 'memref' is a block argument.
       if (!op)
@@ -338,7 +338,7 @@ public:
     const auto &nodeOutEdges = outEdgeIt->second;
     for (auto *op : node->stores) {
       auto storeOp = cast<AffineStoreOp>(op);
-      auto *memref = storeOp.getMemRef();
+      auto memref = storeOp.getMemRef();
       // Skip this store if there are no dependences on its memref. This means
       // that store either:
       // *) writes to a memref that is only read within the same loop nest
@@ -381,7 +381,7 @@ public:
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
+  bool hasEdge(unsigned srcId, unsigned dstId, ValuePtr value = nullptr) {
     if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
       return false;
     }
@@ -395,7 +395,7 @@ public:
   }

   // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
-  void addEdge(unsigned srcId, unsigned dstId, Value *value) {
+  void addEdge(unsigned srcId, unsigned dstId, ValuePtr value) {
     if (!hasEdge(srcId, dstId, value)) {
       outEdges[srcId].push_back({dstId, value});
       inEdges[dstId].push_back({srcId, value});
@@ -405,7 +405,7 @@ public:
   }

   // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
-  void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
+  void removeEdge(unsigned srcId, unsigned dstId, ValuePtr value) {
     assert(inEdges.count(dstId) > 0);
     assert(outEdges.count(srcId) > 0);
     if (value->getType().isa<MemRefType>()) {
@@ -459,7 +459,7 @@ public:

   // Returns the input edge count for node 'id' and 'memref' from src nodes
   // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
+  unsigned getIncomingMemRefAccesses(unsigned id, ValuePtr memref) {
     unsigned inEdgeCount = 0;
     if (inEdges.count(id) > 0)
       for (auto &inEdge : inEdges[id])
@@ -474,7 +474,7 @@ public:

   // Returns the output edge count for node 'id' and 'memref' (if non-null),
   // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
+  unsigned getOutEdgeCount(unsigned id, ValuePtr memref = nullptr) {
     unsigned outEdgeCount = 0;
     if (outEdges.count(id) > 0)
       for (auto &outEdge : outEdges[id])
@@ -548,7 +548,7 @@ public:
   // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
   // has been replaced in node at 'dstId' by a private memref depending
   // on the value of 'createPrivateMemRef'.
-  void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef,
+  void updateEdges(unsigned srcId, unsigned dstId, ValuePtr oldMemRef,
                    bool createPrivateMemRef) {
     // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'.
     if (inEdges.count(srcId) > 0) {
@@ -681,7 +681,7 @@ public:
 // TODO(andydavis) Add support for taking a Block arg to construct the
 // dependence graph at a different depth.
 bool MemRefDependenceGraph::init(FuncOp f) {
-  DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
+  DenseMap<ValuePtr, SetVector<unsigned>> memrefAccesses;

   // TODO: support multi-block functions.
   if (f.getBlocks().size() != 1)
@@ -701,12 +701,12 @@ bool MemRefDependenceGraph::init(FuncOp f) {
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
         node.loads.push_back(opInst);
-        auto *memref = cast<AffineLoadOp>(opInst).getMemRef();
+        auto memref = cast<AffineLoadOp>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       for (auto *opInst : collector.storeOpInsts) {
         node.stores.push_back(opInst);
-        auto *memref = cast<AffineStoreOp>(opInst).getMemRef();
+        auto memref = cast<AffineStoreOp>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       forToNodeMap[&op] = node.id;
@@ -715,14 +715,14 @@ bool MemRefDependenceGraph::init(FuncOp f) {
       // Create graph node for top-level load op.
       Node node(nextNodeId++, &op);
       node.loads.push_back(&op);
-      auto *memref = cast<AffineLoadOp>(op).getMemRef();
+      auto memref = cast<AffineLoadOp>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
       // Create graph node for top-level store op.
       Node node(nextNodeId++, &op);
       node.stores.push_back(&op);
-      auto *memref = cast<AffineStoreOp>(op).getMemRef();
+      auto memref = cast<AffineStoreOp>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (op.getNumRegions() != 0) {
@@ -743,7 +743,7 @@ bool MemRefDependenceGraph::init(FuncOp f) {
     if (!node.loads.empty() || !node.stores.empty())
       continue;
     auto *opInst = node.op;
-    for (auto *value : opInst->getResults()) {
+    for (auto value : opInst->getResults()) {
       for (auto *user : value->getUsers()) {
         SmallVector<AffineForOp, 4> loops;
         getLoopIVs(*user, &loops);
@@ -777,7 +777,7 @@ bool MemRefDependenceGraph::init(FuncOp f) {

 // Removes load operations from 'srcLoads' which operate on 'memref', and
 // adds them to 'dstLoads'.
-static void moveLoadsAccessingMemrefTo(Value *memref,
+static void moveLoadsAccessingMemrefTo(ValuePtr memref,
                                        SmallVectorImpl<Operation *> *srcLoads,
                                        SmallVectorImpl<Operation *> *dstLoads) {
   dstLoads->clear();
@@ -893,10 +893,11 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
 // TODO(bondhugula): consider refactoring the common code from generateDma and
 // this one.
-static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
-                                  unsigned dstLoopDepth,
-                                  Optional<unsigned> fastMemorySpace,
-                                  uint64_t localBufSizeThreshold) {
+static ValuePtr createPrivateMemRef(AffineForOp forOp,
+                                    Operation *srcStoreOpInst,
+                                    unsigned dstLoopDepth,
+                                    Optional<unsigned> fastMemorySpace,
+                                    uint64_t localBufSizeThreshold) {
   auto *forInst = forOp.getOperation();

   // Create builder to insert alloc op just before 'forOp'.
@@ -904,7 +905,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // Builder to create constants at the top level.
   OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
   // Create new memref type based on slice bounds.
-  auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+  auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
   auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
   unsigned rank = oldMemRefType.getRank();

@@ -928,7 +929,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // 'outerIVs' holds the values that this memory region is symbolic/parametric
   // on; this would correspond to loop IVs surrounding the level at which the
   // slice is being materialized.
-  SmallVector<Value *, 8> outerIVs;
+  SmallVector<ValuePtr, 8> outerIVs;
   cst->getIdValues(rank, cst->getNumIds(), &outerIVs);

   // Build 'rank' AffineExprs from MemRefRegion 'lbs'
@@ -960,7 +961,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                        {}, newMemSpace);
   // Gather alloc operands for the dynamic dimensions of the memref.
-  SmallVector<Value *, 4> allocOperands;
+  SmallVector<ValuePtr, 4> allocOperands;
   unsigned dynamicDimCount = 0;
   for (auto dimSize : oldMemRefType.getShape()) {
     if (dimSize == -1)
@@ -973,7 +974,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // consumer loop nests to reduce their live range. Currently they are added
   // at the beginning of the function, because loop nests can be reordered
   // during the fusion pass.
-  Value *newMemRef =
+  ValuePtr newMemRef =
       top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);

   // Build an AffineMap to remap access functions based on lower bound offsets.
@@ -1016,7 +1017,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
                                            MemRefDependenceGraph *mdg) {
   assert(srcLiveOutStoreOp && "Expected a valid store op");
   auto *dstNode = mdg->getNode(dstId);
-  Value *memref = srcLiveOutStoreOp.getMemRef();
+  ValuePtr memref = srcLiveOutStoreOp.getMemRef();
   // Return false if 'srcNode' has more than one output edge on 'memref'.
   if (mdg->getOutEdgeCount(srcId, memref) > 1)
     return false;
@@ -1495,10 +1496,10 @@ public:

       SmallVector<Operation *, 4> loads = dstNode->loads;
       SmallVector<Operation *, 4> dstLoadOpInsts;
-      DenseSet<Value *> visitedMemrefs;
+      DenseSet<ValuePtr> visitedMemrefs;
       while (!loads.empty()) {
         // Get memref of load on top of the stack.
-        auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+        auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
         if (visitedMemrefs.count(memref) > 0)
           continue;
         visitedMemrefs.insert(memref);
@@ -1653,7 +1654,7 @@ public:
           }
           // TODO(andydavis) Use union of memref write regions to compute
           // private memref footprint.
-          auto *newMemRef = createPrivateMemRef(
+          auto newMemRef = createPrivateMemRef(
               dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
               fastMemorySpace, localBufSizeThreshold);
           visitedMemrefs.insert(newMemRef);
@@ -1671,7 +1672,7 @@ public:
           // Add new load ops to current Node load op list 'loads' to
           // continue fusing based on new operands.
           for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-            auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+            auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
             if (visitedMemrefs.count(loadMemRef) == 0)
               loads.push_back(loadOpInst);
           }
@@ -1737,10 +1738,10 @@ public:
   // Attempt to fuse 'dstNode' with sibling nodes in the graph.
   void fuseWithSiblingNodes(Node *dstNode) {
     DenseSet<unsigned> visitedSibNodeIds;
-    std::pair<unsigned, Value *> idAndMemref;
+    std::pair<unsigned, ValuePtr> idAndMemref;
     while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
       unsigned sibId = idAndMemref.first;
-      Value *memref = idAndMemref.second;
+      ValuePtr memref = idAndMemref.second;
       // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
       // stores to the same memref in 'sibNode' loop nest.
       auto *sibNode = mdg->getNode(sibId);
@@ -1804,10 +1805,10 @@ public:
   // 'idAndMemrefToFuse' on success. Returns false otherwise.
   bool findSiblingNodeToFuse(Node *dstNode,
                              DenseSet<unsigned> *visitedSibNodeIds,
-                             std::pair<unsigned, Value *> *idAndMemrefToFuse) {
+                             std::pair<unsigned, ValuePtr> *idAndMemrefToFuse) {
     // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
     // on 'memref'.
-    auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
+    auto canFuseWithSibNode = [&](Node *sibNode, ValuePtr memref) {
      // Skip if 'outEdge' is not a read-after-write dependence.
       // TODO(andydavis) Remove restrict to single load op restriction.
       if (sibNode->getLoadOpCount(memref) != 1)
@@ -1819,15 +1820,15 @@ public:
         return false;
       // Skip sib node if it loads to (and stores from) the same memref on
       // which it also has an input dependence edge.
-      DenseSet<Value *> loadAndStoreMemrefSet;
+      DenseSet<ValuePtr> loadAndStoreMemrefSet;
       sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
-      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
+      if (llvm::any_of(loadAndStoreMemrefSet, [=](ValuePtr memref) {
             return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
           }))
         return false;

       // Check that all stores are to the same memref.
-      DenseSet<Value *> storeMemrefs;
+      DenseSet<ValuePtr> storeMemrefs;
       for (auto *storeOpInst : sibNode->stores) {
         storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
       }
@@ -1856,7 +1857,7 @@ public:
         if (visitedSibNodeIds->count(sibNode->id) > 0)
           continue;
         // Skip 'use' if it does not load from the same memref as 'dstNode'.
-        auto *memref = loadOp.getMemRef();
+        auto memref = loadOp.getMemRef();
         if (dstNode->getLoadOpCount(memref) == 0)
           continue;
         // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
@@ -1950,7 +1951,7 @@ public:
     for (auto &pair : mdg->memrefEdgeCount) {
       if (pair.second > 0)
         continue;
-      auto memref = pair.first;
+      auto memref = pair.first;
       // Skip if there exist other uses (return operation or function calls).
       if (!memref->use_empty())
         continue;
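
Note on the change itself: this diff belongs to MLIR's migration of Value from a pointer-like type (Value *) toward a value-typed handle, with ValuePtr serving as a transitional alias so call sites only had to be rewritten once (including the auto * to auto changes above) before the underlying type flipped. The snippet below is a standalone sketch of that pattern under assumed names (ValueImpl, the sketch namespace); it is not MLIR's actual headers, and the operator-> detail is an inference from the unchanged -> uses in this diff.

#include <cstdio>

namespace sketch {

struct ValueImpl {
  int id; // stand-in for the real IR value's state
};

// A value-typed handle: cheap to copy, passed by value.
class Value {
public:
  explicit Value(ValueImpl *impl) : impl(impl) {}
  int getId() const { return impl->id; }
  // Keeping `->` usable on the handle is what would let call sites such as
  // memref->getDefiningOp() in this diff survive the pointer-to-value flip.
  const Value *operator->() const { return this; }

private:
  ValueImpl *impl;
};

// During step one of the migration this would read `using ValuePtr = Value *;`.
// Once every signature names the alias (and `auto *` has become `auto`, as in
// this diff), flipping it to a value type is a one-line change:
using ValuePtr = Value;

// A call site written against the alias compiles in either state.
unsigned getLoadOpCount(ValuePtr memref) {
  return memref->getId() >= 0 ? 1u : 0u;
}

} // namespace sketch

int main() {
  sketch::ValueImpl impl{7};
  auto memref = sketch::Value(&impl); // note: `auto`, not `auto *`
  std::printf("load op count: %u\n", sketch::getLoadOpCount(memref));
  return 0;
}

The design point to notice: because every signature in the diff names the alias rather than Value *, the eventual switch from pointer to value semantics becomes a one-line change to the alias definition rather than another tree-wide rewrite.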