Diffstat (limited to 'mlir/lib/Transforms/DmaGeneration.cpp')
-rw-r--r--   mlir/lib/Transforms/DmaGeneration.cpp   74
1 file changed, 37 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 69344819ed8..bc7f31f0434 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@@ -49,7 +49,7 @@ namespace {
/// buffers in 'fastMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now.
/// TODO(bondhugula): extend this to store op's.
-struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
+struct DmaGeneration : public FunctionPass, InstWalker<DmaGeneration> {
explicit DmaGeneration(unsigned slowMemorySpace = 0,
unsigned fastMemorySpaceArg = 1,
int minDmaTransferSize = 1024)
@@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
PassResult runOnMLFunction(Function *f) override;
- void runOnForStmt(ForStmt *forStmt);
+ void runOnForInst(ForInst *forInst);
- void visitOperationInst(OperationInst *opStmt);
- bool generateDma(const MemRefRegion &region, ForStmt *forStmt,
+ void visitOperationInst(OperationInst *opInst);
+ bool generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes);
// List of memory regions to DMA for.
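
For readers skimming this change: the InstWalker base mixed in above is a CRTP walker. A minimal sketch of the idiom, assuming the historical API whose names appear in this diff (InstWalker, OperationInst, walk), with LoadCounter as a hypothetical example class:

    #include "mlir/IR/InstVisitor.h"
    #include "mlir/StandardOps/StandardOps.h"
    using namespace mlir;

    // Hypothetical walker using the same CRTP idiom as DmaGeneration:
    // derive from InstWalker<Derived> and override the hooks of interest.
    struct LoadCounter : public InstWalker<LoadCounter> {
      unsigned numLoads = 0;
      // Called for each OperationInst reached by walk().
      void visitOperationInst(OperationInst *opInst) {
        if (opInst->dyn_cast<LoadOp>())
          ++numLoads;
      }
    };
    // Usage, mirroring runOnForInst later in this diff:
    //   LoadCounter counter;
    //   counter.walk(forInst);  // visits every instruction nested in forInst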
@@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
-void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
- if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+void DmaGeneration::visitOperationInst(OperationInst *opInst) {
+ if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
- } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+ } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
} else {
@@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
// This way we would be allocating O(num of memref's) sets instead of
// O(num of load/store op's).
auto region = std::make_unique<MemRefRegion>();
- if (!getMemRefRegion(opStmt, dmaDepth, region.get())) {
+ if (!getMemRefRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
return;
}
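
The TODO in this hunk (allocate O(number of memrefs) region sets rather than one per load/store op) could be realized roughly as below; this is a sketch only, and MemRefRegion::unionWith is a hypothetical helper named purely for illustration, not the historical API:

    #include "llvm/ADT/DenseMap.h"

    // Sketch: key regions by memref and merge repeated accesses, so one
    // region per memref is kept instead of one per load/store op.
    llvm::DenseMap<Value *, std::unique_ptr<MemRefRegion>> regionsByMemRef;

    void addRegionFor(OperationInst *opInst, unsigned dmaDepth) {
      auto region = std::make_unique<MemRefRegion>();
      if (!getMemRefRegion(opInst, dmaDepth, region.get()))
        return;
      auto &entry = regionsByMemRef[region->memref];
      if (!entry)
        entry = std::move(region);    // first access to this memref
      else
        entry->unionWith(*region);    // hypothetical merge of the two regions
    }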
@@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion &region,
// Creates a buffer in the faster memory space for the specified region;
// generates a DMA from the lower memory space to this one, and replaces all
// loads to load from that buffer. Returns true if DMAs are generated.
-bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
+bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes) {
// DMAs for read regions are going to be inserted just before the for loop.
- FuncBuilder prologue(forStmt);
+ FuncBuilder prologue(forInst);
// DMAs for write regions are going to be inserted just after the for loop.
- FuncBuilder epilogue(forStmt->getBlock(),
- std::next(Block::iterator(forStmt)));
+ FuncBuilder epilogue(forInst->getBlock(),
+ std::next(Block::iterator(forInst)));
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
// Builder to create constants at the top level.
- FuncBuilder top(forStmt->getFunction());
+ FuncBuilder top(forInst->getFunction());
- auto loc = forStmt->getLoc();
+ auto loc = forInst->getLoc();
auto *memref = region.memref;
auto memRefType = memref->getType().cast<MemRefType>();
@@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");
- // Create the fast memory space buffer just before the 'for' statement.
+ // Create the fast memory space buffer just before the 'for' instruction.
fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
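
Recording the new buffer in fastBufferMap implies a cache lookup in the code elided just above this hunk; a hedged sketch of that shape (an assumption, since the check itself is not shown in the diff):

    // Assumed shape of the elided lookup: reuse a fast buffer already
    // created for this memref, else allocate one and record it.
    Value *fastMemRef;
    auto cached = fastBufferMap.find(memref);
    if (cached != fastBufferMap.end()) {
      fastMemRef = cached->second;
    } else {
      fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
      fastBufferMap[memref] = fastMemRef;
    }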
@@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
- // *Only* those uses within the body of 'forStmt' are replaced.
+ // *Only* those uses within the body of 'forInst' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domStmtFilter=*/&*forStmt->getBody()->begin());
+ /*domInstFilter=*/&*forInst->getBody()->begin());
return true;
}
-/// Returns the nesting depth of this statement, i.e., the number of loops
-/// surrounding this statement.
+/// Returns the nesting depth of this instruction, i.e., the number of loops
+/// surrounding this instruction.
// TODO(bondhugula): move this to utilities later.
-static unsigned getNestingDepth(const Statement &stmt) {
- const Statement *currStmt = &stmt;
+static unsigned getNestingDepth(const Instruction &inst) {
+ const Instruction *currInst = &inst;
unsigned depth = 0;
- while ((currStmt = currStmt->getParentStmt())) {
- if (isa<ForStmt>(currStmt))
+ while ((currInst = currInst->getParentInst())) {
+ if (isa<ForInst>(currInst))
depth++;
}
return depth;
}
-// TODO(bondhugula): make this run on a Block instead of a 'for' stmt.
-void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
+// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
+void DmaGeneration::runOnForInst(ForInst *forInst) {
// For now (for testing purposes), we'll run this on the outermost among 'for'
- // stmt's with unit stride, i.e., right at the top of the tile if tiling has
+ // inst's with unit stride, i.e., right at the top of the tile if tiling has
// been done. In the future, the DMA generation has to be done at a level
// where the generated data fits in a higher level of the memory hierarchy; so
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
- if (forStmt->getStep() != 1) {
- if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
- runOnForStmt(innerFor);
+ if (forInst->getStep() != 1) {
+ if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
+ runOnForInst(innerFor);
}
return;
}
// DMAs will be generated for this depth, i.e., for all data accessed by this
// loop.
- dmaDepth = getNestingDepth(*forStmt);
+ dmaDepth = getNestingDepth(*forInst);
regions.clear();
fastBufferMap.clear();
- // Walk this 'for' statement to gather all memory regions.
- walk(forStmt);
+ // Walk this 'for' instruction to gather all memory regions.
+ walk(forInst);
uint64_t totalSizeInBytes = 0;
bool ret = false;
for (const auto &region : regions) {
uint64_t sizeInBytes;
- bool iRet = generateDma(*region, forStmt, &sizeInBytes);
+ bool iRet = generateDma(*region, forInst, &sizeInBytes);
if (iRet)
totalSizeInBytes += sizeInBytes;
ret = ret | iRet;
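
To make the index remap earlier in this hunk concrete, a hypothetical worked example: suppose the loop reads A[i + 128] for i in [0, 31]. The region's lower bound yields offsets[0] = 128, so remapExprs holds d0 - 128 and the resulting map (d0) -> (d0 - 128) rewrites the access to fastA[(i + 128) - 128] = fastA[i], landing all reads at the base of a 32-element fast buffer.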
@@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
}
PassResult DmaGeneration::runOnMLFunction(Function *f) {
- for (auto &stmt : *f->getBody()) {
- if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
- runOnForStmt(forStmt);
+ for (auto &inst : *f->getBody()) {
+ if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+ runOnForInst(forInst);
}
}
// This function never leaves the IR in an invalid state.
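
Finally, a minimal usage sketch for the factory named in a hunk header above (createDmaGenerationPass); the argument values are the constructor defaults shown at the top of this diff, and the commented driver call is an assumption about the historical, pre-PassManager API:

    // Slow memory = space 0, fast memory = space 1, and DMAs are only
    // generated above a 1024-byte transfer (the defaults shown above).
    FunctionPass *pass = mlir::createDmaGenerationPass(
        /*slowMemorySpace=*/0, /*fastMemorySpaceArg=*/1,
        /*minDmaTransferSize=*/1024);
    // pass->runOnFunction(f);  // assumed entry point; the era's driver may differ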