diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopTiling.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopTiling.cpp | 58 |
1 files changed, 29 insertions, 29 deletions
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 109953f2296..8f3be8a3d45 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -55,16 +55,16 @@ char LoopTiling::passID = 0; /// Function. FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } -// Move the loop body of ForStmt 'src' from 'src' into the specified location in +// Move the loop body of ForInst 'src' from 'src' into the specified location in // destination's body. -static inline void moveLoopBody(ForStmt *src, ForStmt *dest, +static inline void moveLoopBody(ForInst *src, ForInst *dest, Block::iterator loc) { dest->getBody()->getInstructions().splice(loc, src->getBody()->getInstructions()); } -// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body. -static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { +// Move the loop body of ForInst 'src' from 'src' to the start of dest's body. +static inline void moveLoopBody(ForInst *src, ForInst *dest) { moveLoopBody(src, dest, dest->getBody()->begin()); } @@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) { /// depend on other dimensions. Bounds of each dimension can thus be treated /// independently, and deriving the new bounds is much simpler and faster /// than for the case of tiling arbitrary polyhedral shapes. -static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops, - ArrayRef<ForStmt *> newLoops, +static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, + ArrayRef<ForInst *> newLoops, ArrayRef<unsigned> tileSizes) { assert(!origLoops.empty()); assert(origLoops.size() == tileSizes.size()); @@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops, /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO(bondhugula): handle non hyper-rectangular spaces. -UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, +UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, ArrayRef<unsigned> tileSizes) { assert(!band.empty()); assert(band.size() == tileSizes.size()); - // Check if the supplied for stmt's are all successively nested. + // Check if the supplied for inst's are all successively nested. for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i]->getParentStmt() == band[i - 1]); + assert(band[i]->getParentInst() == band[i - 1]); } auto origLoops = band; - ForStmt *rootForStmt = origLoops[0]; - auto loc = rootForStmt->getLoc(); + ForInst *rootForInst = origLoops[0]; + auto loc = rootForInst->getLoc(); // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector<ForStmt *, 12> newLoops(2 * width); - ForStmt *innermostPointLoop; + SmallVector<ForInst *, 12> newLoops(2 * width); + ForInst *innermostPointLoop; // The outermost among the loops as we add more.. - auto *topLoop = rootForStmt; + auto *topLoop = rootForInst; // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { @@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootForStmt->emitError("tiled code generation unimplemented for the" + rootForInst->emitError("tiled code generation unimplemented for the" "non-hyperrectangular case"); return UtilResult::Failure; } @@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, } // Erase the old loop nest. - rootForStmt->erase(); + rootForInst->erase(); return UtilResult::Success; } @@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band, // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. static void getTileableBands(Function *f, - std::vector<SmallVector<ForStmt *, 6>> *bands) { - // Get maximal perfect nest of 'for' stmts starting from root (inclusive). - auto getMaximalPerfectLoopNest = [&](ForStmt *root) { - SmallVector<ForStmt *, 6> band; - ForStmt *currStmt = root; + std::vector<SmallVector<ForInst *, 6>> *bands) { + // Get maximal perfect nest of 'for' insts starting from root (inclusive). + auto getMaximalPerfectLoopNest = [&](ForInst *root) { + SmallVector<ForInst *, 6> band; + ForInst *currInst = root; do { - band.push_back(currStmt); - } while (currStmt->getBody()->getInstructions().size() == 1 && - (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin()))); + band.push_back(currInst); + } while (currInst->getBody()->getInstructions().size() == 1 && + (currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin()))); bands->push_back(band); }; - for (auto &stmt : *f->getBody()) { - auto *forStmt = dyn_cast<ForStmt>(&stmt); - if (!forStmt) + for (auto &inst : *f->getBody()) { + auto *forInst = dyn_cast<ForInst>(&inst); + if (!forInst) continue; - getMaximalPerfectLoopNest(forStmt); + getMaximalPerfectLoopNest(forInst); } } PassResult LoopTiling::runOnMLFunction(Function *f) { - std::vector<SmallVector<ForStmt *, 6>> bands; + std::vector<SmallVector<ForInst *, 6>> bands; getTileableBands(f, &bands); // Temporary tile sizes. |

