summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopFusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp180
1 files changed, 90 insertions, 90 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index d31337437ad..97dea753f88 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -27,7 +27,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@@ -80,20 +80,20 @@ char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
-static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt,
+static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst,
MemRefAccess *access) {
- if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
+ if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
- access->opStmt = loadOrStoreOpStmt;
+ access->opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) {
access->indices.push_back(index);
}
} else {
- assert(loadOrStoreOpStmt->isa<StoreOp>());
- auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
- access->opStmt = loadOrStoreOpStmt;
+ assert(loadOrStoreOpInst->isa<StoreOp>());
+ auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
+ access->opInst = loadOrStoreOpInst;
access->memref = storeOp->getMemRef();
auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank());
@@ -112,24 +112,24 @@ struct FusionCandidate {
MemRefAccess dstAccess;
};
-static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt,
- OperationInst *dstLoadOpStmt) {
+static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst,
+ OperationInst *dstLoadOpInst) {
FusionCandidate candidate;
// Get store access for src loop nest.
- getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
+ getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess);
// Get load access for dst loop nest.
- getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
+ getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess);
return candidate;
}
-// Returns the loop depth of the loop nest surrounding 'opStmt'.
-static unsigned getLoopDepth(OperationInst *opStmt) {
+// Returns the loop depth of the loop nest surrounding 'opInst'.
+static unsigned getLoopDepth(OperationInst *opInst) {
unsigned loopDepth = 0;
- auto *currStmt = opStmt->getParentStmt();
- ForStmt *currForStmt;
- while (currStmt && (currForStmt = dyn_cast<ForStmt>(currStmt))) {
+ auto *currInst = opInst->getParentInst();
+ ForInst *currForInst;
+ while (currInst && (currForInst = dyn_cast<ForInst>(currInst))) {
++loopDepth;
- currStmt = currStmt->getParentStmt();
+ currInst = currInst->getParentInst();
}
return loopDepth;
}
@@ -137,28 +137,28 @@ static unsigned getLoopDepth(OperationInst *opStmt) {
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not an IfStmt was encountered in the loop nest.
-class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
+// operations, and whether or not an IfInst was encountered in the loop nest.
+class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
- SmallVector<ForStmt *, 4> forStmts;
- SmallVector<OperationInst *, 4> loadOpStmts;
- SmallVector<OperationInst *, 4> storeOpStmts;
- bool hasIfStmt = false;
+ SmallVector<ForInst *, 4> forInsts;
+ SmallVector<OperationInst *, 4> loadOpInsts;
+ SmallVector<OperationInst *, 4> storeOpInsts;
+ bool hasIfInst = false;
- void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
+ void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
- void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
+ void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
- void visitOperationInst(OperationInst *opStmt) {
- if (opStmt->isa<LoadOp>())
- loadOpStmts.push_back(opStmt);
- if (opStmt->isa<StoreOp>())
- storeOpStmts.push_back(opStmt);
+ void visitOperationInst(OperationInst *opInst) {
+ if (opInst->isa<LoadOp>())
+ loadOpInsts.push_back(opInst);
+ if (opInst->isa<StoreOp>())
+ storeOpInsts.push_back(opInst);
}
};
// MemRefDependenceGraph is a graph data structure where graph nodes are
-// top-level statements in a Function which contain load/store ops, and edges
+// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
@@ -170,18 +170,18 @@ public:
// The unique identifier of this node in the graph.
unsigned id;
// The top-level statment which is (or contains) loads/stores.
- Statement *stmt;
+ Instruction *inst;
// List of load operations.
SmallVector<OperationInst *, 4> loads;
- // List of store op stmts.
+ // List of store op insts.
SmallVector<OperationInst *, 4> stores;
- Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
+ Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}
// Returns the load op count for 'memref'.
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
- for (auto *loadOpStmt : loads) {
- if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
+ for (auto *loadOpInst : loads) {
+ if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
++loadOpCount;
}
return loadOpCount;
@@ -190,8 +190,8 @@ public:
// Returns the store op count for 'memref'.
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
- for (auto *storeOpStmt : stores) {
- if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
+ for (auto *storeOpInst : stores) {
+ if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
++storeOpCount;
}
return storeOpCount;
@@ -315,10 +315,10 @@ public:
void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
const SmallVectorImpl<OperationInst *> &stores) {
Node *node = getNode(id);
- for (auto *loadOpStmt : loads)
- node->loads.push_back(loadOpStmt);
- for (auto *storeOpStmt : stores)
- node->stores.push_back(storeOpStmt);
+ for (auto *loadOpInst : loads)
+ node->loads.push_back(loadOpInst);
+ for (auto *storeOpInst : stores)
+ node->stores.push_back(storeOpInst);
}
void print(raw_ostream &os) const {
@@ -341,55 +341,55 @@ public:
void dump() const { print(llvm::errs()); }
};
-// Intializes the data dependence graph by walking statements in 'f'.
+// Intializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
unsigned id = 0;
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
- for (auto &stmt : *f->getBody()) {
- if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
- // Create graph node 'id' to represent top-level 'forStmt' and record
+ for (auto &inst : *f->getBody()) {
+ if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+ // Create graph node 'id' to represent top-level 'forInst' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
- collector.walkForStmt(forStmt);
- // Return false if IfStmts are found (not currently supported).
- if (collector.hasIfStmt)
+ collector.walkForInst(forInst);
+ // Return false if IfInsts are found (not currently supported).
+ if (collector.hasIfInst)
return false;
- Node node(id++, &stmt);
- for (auto *opStmt : collector.loadOpStmts) {
- node.loads.push_back(opStmt);
- auto *memref = opStmt->cast<LoadOp>()->getMemRef();
+ Node node(id++, &inst);
+ for (auto *opInst : collector.loadOpInsts) {
+ node.loads.push_back(opInst);
+ auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
- for (auto *opStmt : collector.storeOpStmts) {
- node.stores.push_back(opStmt);
- auto *memref = opStmt->cast<StoreOp>()->getMemRef();
+ for (auto *opInst : collector.storeOpInsts) {
+ node.stores.push_back(opInst);
+ auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
nodes.insert({node.id, node});
}
- if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
- if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+ if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
- Node node(id++, &stmt);
- node.loads.push_back(opStmt);
- auto *memref = opStmt->cast<LoadOp>()->getMemRef();
+ Node node(id++, &inst);
+ node.loads.push_back(opInst);
+ auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
- if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+ if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
- Node node(id++, &stmt);
- node.stores.push_back(opStmt);
- auto *memref = opStmt->cast<StoreOp>()->getMemRef();
+ Node node(id++, &inst);
+ node.stores.push_back(opInst);
+ auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
}
- // Return false if IfStmts are found (not currently supported).
- if (isa<IfStmt>(&stmt))
+ // Return false if IfInsts are found (not currently supported).
+ if (isa<IfInst>(&inst))
return false;
}
@@ -421,9 +421,9 @@ bool MemRefDependenceGraph::init(Function *f) {
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
-// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate
-// destination ForStmt into which fusion will be attempted.
-// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
+// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
+// destination ForInst into which fusion will be attempted.
+// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
@@ -434,12 +434,12 @@ bool MemRefDependenceGraph::init(Function *f) {
// bounds to be functions of 'dstLoopNest' IVs and symbols.
// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
// just before the dst load op user.
-// *) Add the newly fused load/store operation statements to the state,
+// *) Add the newly fused load/store operation instructions to the state,
// and also add newly fuse load ops to 'dstLoopOps' to be considered
// as fusion dst load ops in another iteration.
// *) Remove old src loop nest and its associated state.
//
-// Given a graph where top-level statements are vertices in the set 'V' and
+// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
@@ -471,14 +471,14 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
- if (!isa<ForStmt>(dstNode->stmt))
+ if (!isa<ForInst>(dstNode->inst))
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
while (!loads.empty()) {
- auto *dstLoadOpStmt = loads.pop_back_val();
- auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
- // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
+ auto *dstLoadOpInst = loads.pop_back_val();
+ auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
+ // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1)
continue;
// Skip if no input edges along which to fuse.
@@ -491,7 +491,7 @@ public:
continue;
auto *srcNode = mdg->getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
- if (!isa<ForStmt>(srcNode->stmt))
+ if (!isa<ForInst>(srcNode->inst))
continue;
// Skip if 'srcNode' has more than one store to 'memref'.
if (srcNode->getStoreOpCount(memref) != 1)
@@ -508,17 +508,17 @@ public:
if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
continue;
// Get unique 'srcNode' store op.
- auto *srcStoreOpStmt = srcNode->stores.front();
- // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
+ auto *srcStoreOpInst = srcNode->stores.front();
+ // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
FusionCandidate candidate =
- buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
+ buildFusionCandidate(srcStoreOpInst, dstLoadOpInst);
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
? clSrcLoopDepth
- : getLoopDepth(srcStoreOpStmt);
+ : getLoopDepth(srcStoreOpInst);
unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
? clDstLoopDepth
- : getLoopDepth(dstLoadOpStmt);
+ : getLoopDepth(dstLoadOpInst);
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
&candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
dstLoopDepth);
@@ -527,19 +527,19 @@ public:
mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
// Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
LoopNestStateCollector collector;
- collector.walkForStmt(sliceLoopNest);
- mdg->addToNode(dstId, collector.loadOpStmts,
- collector.storeOpStmts);
+ collector.walkForInst(sliceLoopNest);
+ mdg->addToNode(dstId, collector.loadOpInsts,
+ collector.storeOpInsts);
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
- for (auto *loadOpStmt : collector.loadOpStmts)
- loads.push_back(loadOpStmt);
+ for (auto *loadOpInst : collector.loadOpInsts)
+ loads.push_back(loadOpInst);
// Promote single iteration loops to single IV value.
- for (auto *forStmt : collector.forStmts) {
- promoteIfSingleIteration(forStmt);
+ for (auto *forInst : collector.forInsts) {
+ promoteIfSingleIteration(forInst);
}
// Remove old src loop nest.
- cast<ForStmt>(srcNode->stmt)->erase();
+ cast<ForInst>(srcNode->inst)->erase();
}
}
}
OpenPOWER on IntegriCloud