diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopUnrollAndJam.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopUnrollAndJam.cpp | 100 |
1 files changed, 50 insertions, 50 deletions
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 60e8d154f98..f59659cf234 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -40,7 +40,7 @@ // S6(i+1); // // Note: 'if/else' blocks are not jammed. So, if there are loops inside if -// stmt's, bodies of those loops will not be jammed. +// inst's, bodies of those loops will not be jammed. //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" @@ -49,7 +49,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnMLFunction(Function *f) override; - bool runOnForStmt(ForStmt *forStmt); + bool runOnForInst(ForInst *forInst); static char passID; }; @@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForStmt can be called on any - // for Stmt. - auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin()); - if (!forStmt) + // unroll-and-jammed by this pass. However, runOnForInst can be called on any + // for Inst. + auto *forInst = dyn_cast<ForInst>(f->getBody()->begin()); + if (!forInst) return success(); - runOnForStmt(forStmt); + runOnForInst(forInst); return success(); } -/// Unroll and jam a 'for' stmt. Default unroll jam factor is +/// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) { +bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forStmt, clUnrollJamFactor); + return loopUnrollJamByFactor(forInst, clUnrollJamFactor); // Unroll and jam by four otherwise. - return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); +bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forStmt, unrollJamFactor); + return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forInst, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. -bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { - // Gathers all maximal sub-blocks of statements that do not themselves include - // a for stmt (a statement could have a descendant for stmt though in its - // tree). - class JamBlockGatherer : public StmtWalker<JamBlockGatherer> { +bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { + // Gathers all maximal sub-blocks of instructions that do not themselves + // include a for inst (a instruction could have a descendant for inst though + // in its tree). + class JamBlockGatherer : public InstWalker<JamBlockGatherer> { public: - using InstListType = llvm::iplist<Statement>; + using InstListType = llvm::iplist<Instruction>; - // Store iterators to the first and last stmt of each sub-block found. + // Store iterators to the first and last inst of each sub-block found. std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; // This is a linear time walk. void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa<ForStmt>(it)) + while (it != End && !isa<ForInst>(it)) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); - // Process all for stmts that appear next. - while (it != End && isa<ForStmt>(it)) - walkForStmt(cast<ForStmt>(it++)); + // Process all for insts that appear next. + while (it != End && isa<ForInst>(it)) + walkForInst(cast<ForInst>(it++)); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forStmt->getBody()->empty()) + if (unrollJamFactor == 1 || forInst->getBody()->empty()) return false; - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt); + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) return false; - auto lbMap = forStmt->getLowerBoundMap(); - auto ubMap = forStmt->getUpperBoundMap(); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. - if (!forStmt->matchingBoundOperandList()) + if (!forInst->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. @@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForStmt(forStmt); + jbg.walkForInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of @@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { DenseMap<const Value *, Value *> operandMap; - // Insert the cleanup loop right after 'forStmt'. - FuncBuilder builder(forStmt->getBlock(), - std::next(Block::iterator(forStmt))); - auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); - cleanupForStmt->setLowerBoundMap( - getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder)); + // Insert the cleanup loop right after 'forInst'. + FuncBuilder builder(forInst->getBlock(), + std::next(Block::iterator(forInst))); + auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap)); + cleanupForInst->setLowerBoundMap( + getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. - forStmt->setUpperBoundMap( - getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder)); + forInst->setUpperBoundMap( + getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForStmt); + promoteIfSingleIteration(cleanupForInst); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forStmt->getStep(); - forStmt->setStep(step * unrollJamFactor); + int64_t step = forInst->getStep(); + forInst->setStep(step * unrollJamFactor); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of @@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forStmt->use_empty()) { + if (!forInst->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto *ivUnroll = - builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) + builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst) ->getResult(0); - operandMapping[forStmt] = ivUnroll; + operandMapping[forInst] = ivUnroll; } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forStmt); + promoteIfSingleIteration(forInst); return true; } |

