diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 67 |
1 files changed, 35 insertions, 32 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d..86e913bd71f 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -21,6 +21,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass { const Optional<bool> unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function<unsigned(const ForInst &)> getUnrollFactor; + const std::function<unsigned(ConstOpPointer<AffineForOp>)> getUnrollFactor; - explicit LoopUnroll( - Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, - const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr) + explicit LoopUnroll(Optional<unsigned> unrollFactor = None, + Optional<bool> unrollFull = None, + const std::function<unsigned(ConstOpPointer<AffineForOp>)> + &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnFunction(Function *f) override; /// Unroll this for inst. Returns false if nothing was done. - bool runOnForInst(ForInst *forInst); + bool runOnAffineForOp(OpPointer<AffineForOp> forOp); static const unsigned kDefaultUnrollFactor = 4; @@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> { public: // Store innermost loops as we walk. - std::vector<ForInst *> loops; + std::vector<OpPointer<AffineForOp>> loops; // This method specialized to encode custom return logic. using InstListType = llvm::iplist<Instruction>; @@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkForInstPostOrder(ForInst *forInst) { - bool hasInnerLoops = - walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); - if (!hasInnerLoops) - loops.push_back(forInst); - return true; - } - bool walkOpInstPostOrder(OperationInst *opInst) { + bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) - if (walkPostOrder(block.begin(), block.end())) - return true; - return false; + hasInnerLoops |= walkPostOrder(block.begin(), block.end()); + if (opInst->isa<AffineForOp>()) { + if (!hasInnerLoops) + loops.push_back(opInst->cast<AffineForOp>()); + return true; + } + return hasInnerLoops; } // FIXME: can't use base class method for this because that in turn would @@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) { class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> { public: // Store short loops as we walk. - std::vector<ForInst *> loops; + std::vector<OpPointer<AffineForOp>> loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForInst(ForInst *forInst) { - Optional<uint64_t> tripCount = getConstantTripCount(*forInst); + void visitOperationInst(OperationInst *opInst) { + auto forOp = opInst->dyn_cast<AffineForOp>(); + if (!forOp) + return; + Optional<uint64_t> tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forInst); + loops.push_back(forOp); } }; @@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { // ones). slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forInst : loops) - loopUnrollFull(forInst); + for (auto forOp : loops) + loopUnrollFull(forOp); return success(); } @@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forInst : loops) - unrolled |= runOnForInst(forInst); + for (auto forOp : loops) + unrolled |= runOnAffineForOp(forOp); if (!unrolled) // Break out if nothing was unrolled. break; @@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) { /// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForInst(ForInst *forInst) { +bool LoopUnroll::runOnAffineForOp(OpPointer<AffineForOp> forOp) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); + return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) - return loopUnrollByFactor(forInst, unrollFactor.getValue()); + return loopUnrollByFactor(forOp, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forInst, clUnrollFactor); + return loopUnrollByFactor(forOp, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forInst); + return loopUnrollFull(forOp); // Unroll by four otherwise. - return loopUnrollByFactor(forInst, kDefaultUnrollFactor); + return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function<unsigned(const ForInst &)> &getUnrollFactor) { + const std::function<unsigned(ConstOpPointer<AffineForOp>)> + &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); |

