summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/PipelineDataTransfer.cpp
diff options
context:
space:
mode:
authorChris Lattner <clattner@google.com>2018-12-23 08:17:48 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:35:19 -0700
commit1301f907a10e25e4a05483977a50c8b4f34b2ed4 (patch)
tree5d289f379659c4d458411958cf631e4d22ada678 /mlir/lib/Transforms/PipelineDataTransfer.cpp
parent4eef795a1dbd7eafa9a45303f01c51921729f1f4 (diff)
downloadbcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.tar.gz
bcm5719-llvm-1301f907a10e25e4a05483977a50c8b4f34b2ed4.zip
Refactor ForStmt: having it contain a StmtBlock instead of subclassing
StmtBlock. This is more consistent with IfStmt and also conceptually makes more sense - a forstmt "isn't" its body, it contains its body. This is step 1/N towards merging BasicBlock and StmtBlock. This is required because in the new regime StmtBlock will have a use list (just like BasicBlock does) of operands, and ForStmt already has a use list for its induction variable. This is a mechanical patch, NFC. PiperOrigin-RevId: 226684158
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp19
1 files changed, 10 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index b656af0d69d..8d75bfbd7ae 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -81,8 +81,9 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
- MLFuncBuilder bInner(forStmt, forStmt->begin());
- bInner.setInsertionPoint(forStmt, forStmt->begin());
+ auto *forBody = forStmt->getBody();
+ MLFuncBuilder bInner(forBody, forBody->begin());
+ bInner.setInsertionPoint(forBody, forBody->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
@@ -127,7 +128,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// non-deferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
ivModTwoOp->getResult(0), AffineMap::Null(), {},
- &*forStmt->begin())) {
+ &*forStmt->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getOperation()->erase();
@@ -184,7 +185,7 @@ static void findMatchingStartFinishStmts(
// Collect outgoing DMA statements - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
@@ -195,7 +196,7 @@ static void findMatchingStartFinishStmts(
}
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
@@ -228,7 +229,7 @@ static void findMatchingStartFinishStmts(
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forStmt->begin(), *use.getOwner())) {
+ if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@@ -339,16 +340,16 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
- for (const auto &stmt : *forStmt) {
+ for (const auto &stmt : *forStmt->getBody()) {
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
stmtShiftMap[&stmt] = 1;
}
}
// Get shifts stored in map.
- std::vector<uint64_t> shifts(forStmt->getStatements().size());
+ std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size());
unsigned s = 0;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
shifts[s++] = stmtShiftMap[&stmt];
LLVM_DEBUG(
OpenPOWER on IntegriCloud