diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopTiling.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopTiling.cpp | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 2a4b7bcd262..396fc8eb658 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -103,7 +103,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, auto mayBeConstantCount = getConstantTripCount(*origLoops[i]); // The lower bound is just the tile-space loop. AffineMap lbMap = b.getDimIdentityMap(); - newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap); + newLoops[width + i]->setLowerBound( + /*operands=*/newLoops[i]->getInductionVar(), lbMap); // Set the upper bound. if (mayBeConstantCount.hasValue() && @@ -117,7 +118,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands()); - ubOperands.push_back(newLoops[i]); + ubOperands.push_back(newLoops[i]->getInductionVar()); auto origUbMap = origLoops[i]->getUpperBoundMap(); SmallVector<AffineExpr, 4> boundExprs; @@ -135,7 +136,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops, // No need of the min expression. auto dim = b.getAffineDimExpr(0); auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {}); - newLoops[width + i]->setUpperBound(newLoops[i], ubMap); + newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap); } } } @@ -194,8 +195,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end()); - SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end()); + SmallVector<Value *, 8> origLoopIVs = extractForInductionVars(band); + SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); @@ -208,7 +209,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band, constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes); // In this case, the point loop IVs just replace the original ones. for (unsigned i = 0; i < width; i++) { - origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]); + origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar()); } // Erase the old loop nest. |

