summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp117
1 files changed, 117 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 2603ea8d806..55cb64b8ba9 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -314,3 +314,120 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
return UtilResult::Success;
}
+
+/// Unrolls this loop completely.
+bool mlir::loopUnrollFull(ForStmt *forStmt) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+ if (mayBeConstantTripCount.hasValue()) {
+ uint64_t tripCount = mayBeConstantTripCount.getValue();
+ if (tripCount == 1) {
+ return promoteIfSingleIteration(forStmt);
+ }
+ return loopUnrollByFactor(forStmt, tripCount);
+ }
+ return false;
+}
+
+/// Unrolls and jams this loop by the specified factor or by the trip count (if
+/// constant) whichever is lower.
+bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollFactor)
+ return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
+ return loopUnrollByFactor(forStmt, unrollFactor);
+}
+
+/// Unrolls this loop by the specified factor. Returns true if the loop
+/// is successfully unrolled.
+bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
+ assert(unrollFactor >= 1 && "unroll factor should be >= 1");
+
+ if (unrollFactor == 1 || forStmt->getStatements().empty())
+ return false;
+
+ auto lbMap = forStmt->getLowerBoundMap();
+ auto ubMap = forStmt->getUpperBoundMap();
+
+ // Loops with max/min expressions won't be unrolled here (the output can't be
+ // expressed as an MLFunction in the general case). However, the right way to
+ // do such unrolling for an MLFunction would be to specialize the loop for the
+ // 'hotspot' case and unroll that hotspot.
+ if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
+ return false;
+
+ // Same operand list for lower and upper bound for now.
+ // TODO(bondhugula): handle bounds with different operand lists.
+ if (!forStmt->matchingBoundOperandList())
+ return false;
+
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+
+ // If the trip count is lower than the unroll factor, no unrolled body.
+ // TODO(bondhugula): option to specify cleanup loop unrolling.
+ if (mayBeConstantTripCount.hasValue() &&
+ mayBeConstantTripCount.getValue() < unrollFactor)
+ return false;
+
+ // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
+ if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
+ DenseMap<const MLValue *, MLValue *> operandMap;
+ MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
+ auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
+ auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
+ assert(clLbMap &&
+ "cleanup loop lower bound map for single result bound maps can "
+ "always be determined");
+ cleanupForStmt->setLowerBoundMap(clLbMap);
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(cleanupForStmt);
+
+ // Adjust upper bound.
+ auto unrolledUbMap =
+ getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
+ assert(unrolledUbMap &&
+ "upper bound map can alwayys be determined for an unrolled loop "
+ "with single result bounds");
+ forStmt->setUpperBoundMap(unrolledUbMap);
+ }
+
+ // Scale the step of loop being unrolled by unroll factor.
+ int64_t step = forStmt->getStep();
+ forStmt->setStep(step * unrollFactor);
+
+ // Builder to insert unrolled bodies right after the last statement in the
+ // body of 'forStmt'.
+ MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+
+ // Keep a pointer to the last statement in the original block so that we know
+ // what to clone (since we are doing this in-place).
+ StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end());
+
+ // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
+ for (unsigned i = 1; i < unrollFactor; i++) {
+ DenseMap<const MLValue *, MLValue *> operandMap;
+
+ // If the induction variable is used, create a remapping to the value for
+ // this unrolled instance.
+ if (!forStmt->use_empty()) {
+ // iv' = iv + 1/2/3...unrollFactor-1;
+ auto d0 = builder.getAffineDimExpr(0);
+ auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
+ auto *ivUnroll =
+ builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
+ ->getResult(0);
+ operandMap[forStmt] = cast<MLValue>(ivUnroll);
+ }
+
+ // Clone the original body of 'forStmt'.
+ for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) {
+ builder.clone(*it, operandMap);
+ }
+ }
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(forStmt);
+
+ return true;
+}
OpenPOWER on IntegriCloud