summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopTiling.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopTiling.cpp')
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 2a4b7bcd262..396fc8eb658 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -103,7 +103,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
// The lower bound is just the tile-space loop.
AffineMap lbMap = b.getDimIdentityMap();
- newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap);
+ newLoops[width + i]->setLowerBound(
+ /*operands=*/newLoops[i]->getInductionVar(), lbMap);
// Set the upper bound.
if (mayBeConstantCount.hasValue() &&
@@ -117,7 +118,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
// with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended.
SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
- ubOperands.push_back(newLoops[i]);
+ ubOperands.push_back(newLoops[i]->getInductionVar());
auto origUbMap = origLoops[i]->getUpperBoundMap();
SmallVector<AffineExpr, 4> boundExprs;
@@ -135,7 +136,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
// No need of the min expression.
auto dim = b.getAffineDimExpr(0);
auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
- newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
+ newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap);
}
}
}
@@ -194,8 +195,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
// Move the loop body of the original nest to the new one.
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
- SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
- SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
+ SmallVector<Value *, 8> origLoopIVs = extractForInductionVars(band);
+ SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
FlatAffineConstraints cst;
getIndexSet(band, &cst);
@@ -208,7 +209,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
// In this case, the point loop IVs just replace the original ones.
for (unsigned i = 0; i < width; i++) {
- origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
+ origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar());
}
// Erase the old loop nest.
OpenPOWER on IntegriCloud