diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopTiling.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopTiling.cpp | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 85c88f785d1..847db83aebc 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -59,12 +59,12 @@ FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } // destination's body. static inline void moveLoopBody(ForStmt *src, ForStmt *dest, StmtBlock::iterator loc) { - dest->getStatements().splice(loc, src->getStatements()); + dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements()); } // Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { - moveLoopBody(src, dest, dest->begin()); + moveLoopBody(src, dest, dest->getBody()->begin()); } /// Constructs and sets new loop bounds after tiling for the case of @@ -167,8 +167,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, MLFuncBuilder b(topLoop); // Loop bounds will be set later. auto *pointLoop = b.createFor(loc, 0, 0); - pointLoop->getStatements().splice( - pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop); + pointLoop->getBody()->getStatements().splice( + pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), + topLoop); newLoops[2 * width - 1 - i] = pointLoop; topLoop = pointLoop; if (i == 0) @@ -180,8 +181,9 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, MLFuncBuilder b(topLoop); // Loop bounds will be set later. auto *tileSpaceLoop = b.createFor(loc, 0, 0); - tileSpaceLoop->getStatements().splice( - tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop); + tileSpaceLoop->getBody()->getStatements().splice( + tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(), + topLoop); newLoops[2 * width - i - 1] = tileSpaceLoop; topLoop = tileSpaceLoop; } @@ -223,8 +225,8 @@ static void getTileableBands(MLFunction *f, ForStmt *currStmt = root; do { band.push_back(currStmt); - } while (currStmt->getStatements().size() == 1 && - (currStmt = dyn_cast<ForStmt>(&*currStmt->begin()))); + } while (currStmt->getBody()->getStatements().size() == 1 && + (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin()))); bands->push_back(band); }; |

