summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopTiling.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopTiling.cpp')
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp58
1 files changed, 29 insertions, 29 deletions
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 109953f2296..8f3be8a3d45 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -55,16 +55,16 @@ char LoopTiling::passID = 0;
/// Function.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
-// Move the loop body of ForStmt 'src' from 'src' into the specified location in
+// Move the loop body of ForInst 'src' from 'src' into the specified location in
// destination's body.
-static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
+static inline void moveLoopBody(ForInst *src, ForInst *dest,
Block::iterator loc) {
dest->getBody()->getInstructions().splice(loc,
src->getBody()->getInstructions());
}
-// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
-static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
+// Move the loop body of ForInst 'src' from 'src' to the start of dest's body.
+static inline void moveLoopBody(ForInst *src, ForInst *dest) {
moveLoopBody(src, dest, dest->getBody()->begin());
}
@@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
-static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
- ArrayRef<ForStmt *> newLoops,
+static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
+ ArrayRef<ForInst *> newLoops,
ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
@@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non hyper-rectangular spaces.
-UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
+UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
ArrayRef<unsigned> tileSizes) {
assert(!band.empty());
assert(band.size() == tileSizes.size());
- // Check if the supplied for stmt's are all successively nested.
+ // Check if the supplied for inst's are all successively nested.
for (unsigned i = 1, e = band.size(); i < e; i++) {
- assert(band[i]->getParentStmt() == band[i - 1]);
+ assert(band[i]->getParentInst() == band[i - 1]);
}
auto origLoops = band;
- ForStmt *rootForStmt = origLoops[0];
- auto loc = rootForStmt->getLoc();
+ ForInst *rootForInst = origLoops[0];
+ auto loc = rootForInst->getLoc();
// Note that width is at least one since band isn't empty.
unsigned width = band.size();
- SmallVector<ForStmt *, 12> newLoops(2 * width);
- ForStmt *innermostPointLoop;
+ SmallVector<ForInst *, 12> newLoops(2 * width);
+ ForInst *innermostPointLoop;
// The outermost among the loops as we add more..
- auto *topLoop = rootForStmt;
+ auto *topLoop = rootForInst;
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
@@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
- rootForStmt->emitError("tiled code generation unimplemented for the"
+ rootForInst->emitError("tiled code generation unimplemented for the"
"non-hyperrectangular case");
return UtilResult::Failure;
}
@@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
}
// Erase the old loop nest.
- rootForStmt->erase();
+ rootForInst->erase();
return UtilResult::Success;
}
@@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function *f,
- std::vector<SmallVector<ForStmt *, 6>> *bands) {
- // Get maximal perfect nest of 'for' stmts starting from root (inclusive).
- auto getMaximalPerfectLoopNest = [&](ForStmt *root) {
- SmallVector<ForStmt *, 6> band;
- ForStmt *currStmt = root;
+ std::vector<SmallVector<ForInst *, 6>> *bands) {
+ // Get maximal perfect nest of 'for' insts starting from root (inclusive).
+ auto getMaximalPerfectLoopNest = [&](ForInst *root) {
+ SmallVector<ForInst *, 6> band;
+ ForInst *currInst = root;
do {
- band.push_back(currStmt);
- } while (currStmt->getBody()->getInstructions().size() == 1 &&
- (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
+ band.push_back(currInst);
+ } while (currInst->getBody()->getInstructions().size() == 1 &&
+ (currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin())));
bands->push_back(band);
};
- for (auto &stmt : *f->getBody()) {
- auto *forStmt = dyn_cast<ForStmt>(&stmt);
- if (!forStmt)
+ for (auto &inst : *f->getBody()) {
+ auto *forInst = dyn_cast<ForInst>(&inst);
+ if (!forInst)
continue;
- getMaximalPerfectLoopNest(forStmt);
+ getMaximalPerfectLoopNest(forInst);
}
}
PassResult LoopTiling::runOnMLFunction(Function *f) {
- std::vector<SmallVector<ForStmt *, 6>> bands;
+ std::vector<SmallVector<ForInst *, 6>> bands;
getTileableBands(f, &bands);
// Temporary tile sizes.
OpenPOWER on IntegriCloud