diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-23 08:17:48 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:35:19 -0700 |
| commit | 1301f907a10e25e4a05483977a50c8b4f34b2ed4 (patch) | |
| tree | 5d289f379659c4d458411958cf631e4d22ada678 /mlir/lib/Transforms/PipelineDataTransfer.cpp | |
| parent | 4eef795a1dbd7eafa9a45303f01c51921729f1f4 (diff) | |
| download | bcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.tar.gz bcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.zip | |
Refactor ForStmt: having it contain a StmtBlock instead of subclassing
StmtBlock. This is more consistent with IfStmt and also conceptually makes
more sense - a forstmt "isn't" its body, it contains its body.
This is step 1/N towards merging BasicBlock and StmtBlock. This is required
because in the new regime StmtBlock will have a use list (just like BasicBlock
does) of operands, and ForStmt already has a use list for its induction
variable.
This is a mechanical patch, NFC.
PiperOrigin-RevId: 226684158
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index b656af0d69d..8d75bfbd7ae 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { /// the loop IV of the specified 'for' statement modulo 2. Returns false if such /// a replacement cannot be performed. static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { - MLFuncBuilder bInner(forStmt, forStmt->begin()); - bInner.setInsertionPoint(forStmt, forStmt->begin()); + auto *forBody = forStmt->getBody(); + MLFuncBuilder bInner(forBody, forBody->begin()); + bInner.setInsertionPoint(forBody, forBody->begin()); // Doubles the shape with a leading dimension extent of 2. auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType { @@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef), ivModTwoOp->getResult(0), AffineMap::Null(), {}, - &*forStmt->begin())) { + &*forStmt->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getOperation()->erase(); @@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts( // Collect outgoing DMA statements - needed to check for dependences below. SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; @@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts( } SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; @@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts( cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos())); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forStmt->begin(), *use.getOwner())) { + if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &stmt : *forStmt) { + for (const auto &stmt : *forStmt->getBody()) { if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) { stmtShiftMap[&stmt] = 1; } } // Get shifts stored in map. - std::vector<uint64_t> shifts(forStmt->getStatements().size()); + std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size()); unsigned s = 0; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end()); shifts[s++] = stmtShiftMap[&stmt]; LLVM_DEBUG( |

