diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/LoopUtils.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopUtils.cpp | 117 |
1 files changed, 117 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 2603ea8d806..55cb64b8ba9 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -314,3 +314,120 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays, return UtilResult::Success; } + +/// Unrolls this loop completely. +bool mlir::loopUnrollFull(ForStmt *forStmt) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); + if (mayBeConstantTripCount.hasValue()) { + uint64_t tripCount = mayBeConstantTripCount.getValue(); + if (tripCount == 1) { + return promoteIfSingleIteration(forStmt); + } + return loopUnrollByFactor(forStmt, tripCount); + } + return false; +} + +/// Unrolls and jams this loop by the specified factor or by the trip count (if +/// constant) whichever is lower. +bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); + + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollFactor) + return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue()); + return loopUnrollByFactor(forStmt, unrollFactor); +} + +/// Unrolls this loop by the specified factor. Returns true if the loop +/// is successfully unrolled. +bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) { + assert(unrollFactor >= 1 && "unroll factor should be >= 1"); + + if (unrollFactor == 1 || forStmt->getStatements().empty()) + return false; + + auto lbMap = forStmt->getLowerBoundMap(); + auto ubMap = forStmt->getUpperBoundMap(); + + // Loops with max/min expressions won't be unrolled here (the output can't be + // expressed as an MLFunction in the general case). However, the right way to + // do such unrolling for an MLFunction would be to specialize the loop for the + // 'hotspot' case and unroll that hotspot. + if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1) + return false; + + // Same operand list for lower and upper bound for now. + // TODO(bondhugula): handle bounds with different operand lists. + if (!forStmt->matchingBoundOperandList()) + return false; + + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); + + // If the trip count is lower than the unroll factor, no unrolled body. + // TODO(bondhugula): option to specify cleanup loop unrolling. + if (mayBeConstantTripCount.hasValue() && + mayBeConstantTripCount.getValue() < unrollFactor) + return false; + + // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. + if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) { + DenseMap<const MLValue *, MLValue *> operandMap; + MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt)); + auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); + auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder); + assert(clLbMap && + "cleanup loop lower bound map for single result bound maps can " + "always be determined"); + cleanupForStmt->setLowerBoundMap(clLbMap); + // Promote the loop body up if this has turned into a single iteration loop. + promoteIfSingleIteration(cleanupForStmt); + + // Adjust upper bound. + auto unrolledUbMap = + getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder); + assert(unrolledUbMap && + "upper bound map can alwayys be determined for an unrolled loop " + "with single result bounds"); + forStmt->setUpperBoundMap(unrolledUbMap); + } + + // Scale the step of loop being unrolled by unroll factor. + int64_t step = forStmt->getStep(); + forStmt->setStep(step * unrollFactor); + + // Builder to insert unrolled bodies right after the last statement in the + // body of 'forStmt'. + MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end())); + + // Keep a pointer to the last statement in the original block so that we know + // what to clone (since we are doing this in-place). + StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end()); + + // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). + for (unsigned i = 1; i < unrollFactor; i++) { + DenseMap<const MLValue *, MLValue *> operandMap; + + // If the induction variable is used, create a remapping to the value for + // this unrolled instance. + if (!forStmt->use_empty()) { + // iv' = iv + 1/2/3...unrollFactor-1; + auto d0 = builder.getAffineDimExpr(0); + auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); + auto *ivUnroll = + builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) + ->getResult(0); + operandMap[forStmt] = cast<MLValue>(ivUnroll); + } + + // Clone the original body of 'forStmt'. + for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) { + builder.clone(*it, operandMap); + } + } + + // Promote the loop body up if this has turned into a single iteration loop. + promoteIfSingleIteration(forStmt); + + return true; +} |

