diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopUnrollAndJam.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopUnrollAndJam.cpp | 86 |
1 files changed, 47 insertions, 39 deletions
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 7deaf850362..7327a37ee3a 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -43,6 +43,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass { unrollJamFactor(unrollJamFactor) {} PassResult runOnFunction(Function *f) override; - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer<AffineForOp> forOp); static char passID; }; @@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // Currently, just the outermost loop from the first loop nest is - // unroll-and-jammed by this pass. However, runOnForInst can be called on any - // for Inst. + // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on + // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front())) - runOnForInst(forInst); + if (auto forOp = + cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>()) + runOnAffineForOp(forOp); return success(); } /// Unroll and jam a 'for' inst. Default unroll jam factor is /// kDefaultUnrollJamFactor. Return false if nothing was done. -bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) { +bool LoopUnrollAndJam::runOnAffineForOp(OpPointer<AffineForOp> forOp) { // Unroll and jam by the factor that was passed if any. if (unrollJamFactor.hasValue()) - return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue()); // Otherwise, unroll jam by the command-line factor if one was specified. if (clUnrollJamFactor.getNumOccurrences() > 0) - return loopUnrollJamByFactor(forInst, clUnrollJamFactor); + return loopUnrollJamByFactor(forOp, clUnrollJamFactor); // Unroll and jam by four otherwise. - return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor); + return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor); } -bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) { - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); +bool mlir::loopUnrollJamUpToFactor(OpPointer<AffineForOp> forOp, + uint64_t unrollJamFactor) { + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() < unrollJamFactor) - return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue()); - return loopUnrollJamByFactor(forInst, unrollJamFactor); + return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue()); + return loopUnrollJamByFactor(forOp, unrollJamFactor); } /// Unrolls and jams this loop by the specified factor. -bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { +bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp, + uint64_t unrollJamFactor) { // Gathers all maximal sub-blocks of instructions that do not themselves // include a for inst (a instruction could have a descendant for inst though // in its tree). class JamBlockGatherer : public InstWalker<JamBlockGatherer> { public: using InstListType = llvm::iplist<Instruction>; + using InstWalker<JamBlockGatherer>::walk; // Store iterators to the first and last inst of each sub-block found. std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks; @@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !isa<ForInst>(it)) + while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && isa<ForInst>(it)) - walkForInst(cast<ForInst>(it++)); + while (it != End && cast<OperationInst>(it)->isa<AffineForOp>()) + walk(&*it++); } } }; assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); - if (unrollJamFactor == 1 || forInst->getBody()->empty()) + if (unrollJamFactor == 1 || forOp->getBody()->empty()) return false; - Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst); + Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); if (!mayBeConstantTripCount.hasValue() && - getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0) + getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) return false; - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + auto lbMap = forOp->getLowerBoundMap(); + auto ubMap = forOp->getUpperBoundMap(); // Loops with max/min expressions won't be unrolled here (the output can't be // expressed as a Function in the general case). However, the right way to @@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // Same operand list for lower and upper bound for now. // TODO(bondhugula): handle bounds with different sets of operands. - if (!forInst->matchingBoundOperandList()) + if (!forOp->matchingBoundOperandList()) return false; // If the trip count is lower than the unroll jam factor, no unroll jam. @@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { mayBeConstantTripCount.getValue() < unrollJamFactor) return false; + auto *forInst = forOp->getInstruction(); + // Gather all sub-blocks to jam upon the loop being unrolled. JamBlockGatherer jbg; - jbg.walkForInst(forInst); + jbg.walkOpInst(forInst); auto &subBlocks = jbg.subBlocks; // Generate the cleanup loop if trip count isn't a multiple of // unrollJamFactor. if (mayBeConstantTripCount.hasValue() && mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { - // Insert the cleanup loop right after 'forInst'. + // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst)); - cleanupForInst->setLowerBoundMap( - getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder)); + auto cleanupAffineForOp = + cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>(); + cleanupAffineForOp->setLowerBoundMap( + getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); // The upper bound needs to be adjusted. - forInst->setUpperBoundMap( - getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder)); + forOp->setUpperBoundMap( + getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder)); // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(cleanupForInst); + promoteIfSingleIteration(cleanupAffineForOp); } // Scale the step of loop being unroll-jammed by the unroll-jam factor. - int64_t step = forInst->getStep(); - forInst->setStep(step * unrollJamFactor); + int64_t step = forOp->getStep(); + forOp->setStep(step * unrollJamFactor); - auto *forInstIV = forInst->getInductionVar(); + auto *forOpIV = forOp->getInductionVar(); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. @@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInstIV->use_empty()) { + if (!forOpIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); - auto ivUnroll = builder.create<AffineApplyOp>(forInst->getLoc(), - bumpMap, forInstIV); - operandMapping.map(forInstIV, ivUnroll); + auto ivUnroll = + builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV); + operandMapping.map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { @@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { } // Promote the loop body up if this has turned into a single iteration loop. - promoteIfSingleIteration(forInst); + promoteIfSingleIteration(forOp); return true; } |

