Diffstat (limited to 'mlir/lib/Transforms/DmaGeneration.cpp')
-rw-r--r--  mlir/lib/Transforms/DmaGeneration.cpp | 74
1 file changed, 37 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 69344819ed8..bc7f31f0434 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -25,7 +25,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Transforms/Passes.h"
@@ -49,7 +49,7 @@ namespace {
 /// buffers in 'fastMemorySpace', and replaces memory operations to the former
 /// by the latter. Only load op's handled for now.
 /// TODO(bondhugula): extend this to store op's.
-struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
+struct DmaGeneration : public FunctionPass, InstWalker<DmaGeneration> {
   explicit DmaGeneration(unsigned slowMemorySpace = 0,
                          unsigned fastMemorySpaceArg = 1,
                          int minDmaTransferSize = 1024)
@@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
   // Not applicable to CFG functions.
   PassResult runOnCFGFunction(Function *f) override { return success(); }
   PassResult runOnMLFunction(Function *f) override;
-  void runOnForStmt(ForStmt *forStmt);
+  void runOnForInst(ForInst *forInst);
 
-  void visitOperationInst(OperationInst *opStmt);
-  bool generateDma(const MemRefRegion &region, ForStmt *forStmt,
+  void visitOperationInst(OperationInst *opInst);
+  bool generateDma(const MemRefRegion &region, ForInst *forInst,
                    uint64_t *sizeInBytes);
 
   // List of memory regions to DMA for.
@@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
 
 // Gather regions to promote to buffers in faster memory space.
 // TODO(bondhugula): handle store op's; only load's handled for now.
-void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
-  if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+void DmaGeneration::visitOperationInst(OperationInst *opInst) {
+  if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
     if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
       return;
-  } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+  } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
     if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace)
       return;
   } else {
@@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
   // This way we would be allocating O(num of memref's) sets instead of
   // O(num of load/store op's).
   auto region = std::make_unique<MemRefRegion>();
-  if (!getMemRefRegion(opStmt, dmaDepth, region.get())) {
+  if (!getMemRefRegion(opInst, dmaDepth, region.get())) {
     LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
     return;
   }
@@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion &region,
 // Creates a buffer in the faster memory space for the specified region;
 // generates a DMA from the lower memory space to this one, and replaces all
 // loads to load from that buffer. Returns true if DMAs are generated.
-bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
+bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
                                 uint64_t *sizeInBytes) {
   // DMAs for read regions are going to be inserted just before the for loop.
-  FuncBuilder prologue(forStmt);
+  FuncBuilder prologue(forInst);
   // DMAs for write regions are going to be inserted just after the for loop.
-  FuncBuilder epilogue(forStmt->getBlock(),
-                       std::next(Block::iterator(forStmt)));
+  FuncBuilder epilogue(forInst->getBlock(),
+                       std::next(Block::iterator(forInst)));
   FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
 
   // Builder to create constants at the top level.
-  FuncBuilder top(forStmt->getFunction());
+  FuncBuilder top(forInst->getFunction());
 
-  auto loc = forStmt->getLoc();
+  auto loc = forInst->getLoc();
   auto *memref = region.memref;
   auto memRefType = memref->getType().cast<MemRefType>();
@@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
   LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
   LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");
 
-  // Create the fast memory space buffer just before the 'for' statement.
+  // Create the fast memory space buffer just before the 'for' instruction.
   fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
   // Record it.
   fastBufferMap[memref] = fastMemRef;
@@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
     remapExprs.push_back(dimExpr - offsets[i]);
   }
   auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
-  // *Only* those uses within the body of 'forStmt' are replaced.
+  // *Only* those uses within the body of 'forInst' are replaced.
   replaceAllMemRefUsesWith(memref, fastMemRef,
                            /*extraIndices=*/{}, indexRemap,
                            /*extraOperands=*/outerIVs,
-                           /*domStmtFilter=*/&*forStmt->getBody()->begin());
+                           /*domInstFilter=*/&*forInst->getBody()->begin());
   return true;
 }
 
-/// Returns the nesting depth of this statement, i.e., the number of loops
-/// surrounding this statement.
+/// Returns the nesting depth of this instruction, i.e., the number of loops
+/// surrounding this instruction.
 // TODO(bondhugula): move this to utilities later.
-static unsigned getNestingDepth(const Statement &stmt) {
-  const Statement *currStmt = &stmt;
+static unsigned getNestingDepth(const Instruction &inst) {
+  const Instruction *currInst = &inst;
   unsigned depth = 0;
-  while ((currStmt = currStmt->getParentStmt())) {
-    if (isa<ForStmt>(currStmt))
+  while ((currInst = currInst->getParentInst())) {
+    if (isa<ForInst>(currInst))
       depth++;
   }
   return depth;
 }
 
-// TODO(bondhugula): make this run on a Block instead of a 'for' stmt.
-void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
+// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
+void DmaGeneration::runOnForInst(ForInst *forInst) {
   // For now (for testing purposes), we'll run this on the outermost among 'for'
-  // stmt's with unit stride, i.e., right at the top of the tile if tiling has
+  // inst's with unit stride, i.e., right at the top of the tile if tiling has
   // been done. In the future, the DMA generation has to be done at a level
   // where the generated data fits in a higher level of the memory hierarchy; so
   // the pass has to be instantiated with additional information that we aren't
   // provided with at the moment.
-  if (forStmt->getStep() != 1) {
-    if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
-      runOnForStmt(innerFor);
+  if (forInst->getStep() != 1) {
+    if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
+      runOnForInst(innerFor);
     }
     return;
   }
 
   // DMAs will be generated for this depth, i.e., for all data accessed by this
   // loop.
-  dmaDepth = getNestingDepth(*forStmt);
+  dmaDepth = getNestingDepth(*forInst);
 
   regions.clear();
   fastBufferMap.clear();
 
-  // Walk this 'for' statement to gather all memory regions.
-  walk(forStmt);
+  // Walk this 'for' instruction to gather all memory regions.
+  walk(forInst);
 
   uint64_t totalSizeInBytes = 0;
   bool ret = false;
   for (const auto &region : regions) {
     uint64_t sizeInBytes;
-    bool iRet = generateDma(*region, forStmt, &sizeInBytes);
+    bool iRet = generateDma(*region, forInst, &sizeInBytes);
     if (iRet)
       totalSizeInBytes += sizeInBytes;
     ret = ret | iRet;
@@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
 }
 
 PassResult DmaGeneration::runOnMLFunction(Function *f) {
-  for (auto &stmt : *f->getBody()) {
-    if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
-      runOnForStmt(forStmt);
+  for (auto &inst : *f->getBody()) {
+    if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+      runOnForInst(forInst);
     }
   }
   // This function never leaves the IR in an invalid state.
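
For context on the pattern the rename above touches: DmaGeneration gathers memory regions by deriving from a CRTP walker and overriding a per-operation visit hook, which walk() invokes while recursing through nested loop bodies. The sketch below is a minimal, self-contained illustration of that shape only; Inst, ForInst, OperationInst, and InstWalker here are simplified stand-ins for illustration, not the actual MLIR classes (the real ones carry far more state and dispatch on more instruction kinds).

// Hypothetical, simplified sketch of the CRTP walker shape used above.
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Inst {
  virtual ~Inst() = default;
};

struct OperationInst : Inst {
  std::string name;
  explicit OperationInst(std::string n) : name(std::move(n)) {}
};

struct ForInst : Inst {
  std::vector<std::unique_ptr<Inst>> body; // nested instructions
};

// CRTP base: walk() recurses into 'for' bodies and hands each operation
// to the derived class's visitOperationInst hook.
template <typename Derived>
struct InstWalker {
  void walk(Inst *inst) {
    if (auto *forInst = dynamic_cast<ForInst *>(inst)) {
      for (auto &child : forInst->body)
        walk(child.get());
    } else if (auto *opInst = dynamic_cast<OperationInst *>(inst)) {
      static_cast<Derived *>(this)->visitOperationInst(opInst);
    }
  }
};

// Analogue of DmaGeneration: collects per-operation info during the walk.
struct OpCollector : InstWalker<OpCollector> {
  void visitOperationInst(OperationInst *opInst) {
    std::cout << "visited " << opInst->name << "\n";
  }
};

int main() {
  ForInst loop;
  loop.body.push_back(std::make_unique<OperationInst>("load"));
  loop.body.push_back(std::make_unique<OperationInst>("store"));
  OpCollector collector;
  collector.walk(&loop); // prints: visited load, then visited store
}

Because the dispatch happens through the derived type rather than virtual calls, the commit can rename the base from StmtWalker to InstWalker without changing how visitOperationInst is found: only the inherited-from template and the hook's terminology move.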

