diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/LoopUtils.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 791997e7ff1..4d75f7c0835 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -119,7 +119,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) { // Move the loop body statements to the loop's containing block. auto *block = forStmt->getBlock(); block->getStatements().splice(StmtBlock::iterator(forStmt), - forStmt->getStatements()); + forStmt->getBody()->getStatements()); forStmt->erase(); return true; } @@ -181,7 +181,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, operandMap[srcForStmt] = loopChunk; } for (auto *stmt : stmts) { - loopChunk->push_back(stmt->clone(operandMap, b->getContext())); + loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext())); } } if (promoteIfSingleIteration(loopChunk)) @@ -206,7 +206,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // method. UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, bool unrollPrologueEpilogue) { - if (forStmt->getStatements().empty()) + if (forStmt->getBody()->empty()) return UtilResult::Success; // If the trip counts aren't constant, we would need versioning and @@ -225,7 +225,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, int64_t step = forStmt->getStep(); - unsigned numChildStmts = forStmt->getStatements().size(); + unsigned numChildStmts = forStmt->getBody()->getStatements().size(); // Do a linear time (counting) sort for the shifts. uint64_t maxShift = 0; @@ -243,7 +243,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts, // body of the 'for' stmt. std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1); unsigned pos = 0; - for (auto &stmt : *forStmt) { + for (auto &stmt : *forStmt->getBody()) { auto shift = shifts[pos++]; sortedStmtGroups[shift].push_back(&stmt); } @@ -352,7 +352,7 @@ bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) { bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { assert(unrollFactor >= 1 && "unroll factor should be >= 1"); - if (unrollFactor == 1 || forStmt->getStatements().empty()) + if (unrollFactor == 1 || forStmt->getBody()->empty()) return false; auto lbMap = forStmt->getLowerBoundMap(); @@ -406,11 +406,11 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { // Builder to insert unrolled bodies right after the last statement in the // body of 'forStmt'. - MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end())); + MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end()); // Keep a pointer to the last statement in the original block so that we know // what to clone (since we are doing this in-place). - StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end()); + StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end()); // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). for (unsigned i = 1; i < unrollFactor; i++) { @@ -429,7 +429,8 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { } // Clone the original body of 'forStmt'. - for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) { + for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd); + it++) { builder.clone(*it, operandMap); } } |

