diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-28 16:05:35 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:44:30 -0700 |
| commit | 456ad6a8e0ca78ce6277da897a0b820533387d84 (patch) | |
| tree | d9fbb26651eed51b02281be03c9bbc66522cbacf /mlir/lib/Transforms/LoopUnroll.cpp | |
| parent | b1d9cc4d1ef5a1f81ca566fc06960df2bf31ddfe (diff) | |
| download | bcm5719-llvm-456ad6a8e0ca78ce6277da897a0b820533387d84.tar.gz bcm5719-llvm-456ad6a8e0ca78ce6277da897a0b820533387d84.zip | |
Standardize naming of statements -> instructions, revisting the code base to be
consistent and moving the using declarations over. Hopefully this is the last
truly massive patch in this refactoring.
This is step 21/n towards merging instructions and statements, NFC.
PiperOrigin-RevId: 227178245
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 68 |
1 files changed, 34 insertions, 34 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 15ea0f841cc..69431bf6349 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -26,7 +26,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/StmtVisitor.h" +#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" @@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass { const Optional<bool> unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes // precedence over command-line argument or passed argument. - const std::function<unsigned(const ForStmt &)> getUnrollFactor; + const std::function<unsigned(const ForInst &)> getUnrollFactor; explicit LoopUnroll( Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, - const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr) + const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnMLFunction(Function *f) override; - /// Unroll this for stmt. Returns false if nothing was done. - bool runOnForStmt(ForStmt *forStmt); + /// Unroll this for inst. Returns false if nothing was done. + bool runOnForInst(ForInst *forInst); static const unsigned kDefaultUnrollFactor = 4; @@ -85,13 +85,13 @@ char LoopUnroll::passID = 0; PassResult LoopUnroll::runOnMLFunction(Function *f) { // Gathers all innermost loops through a post order pruned walk. - class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> { + class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> { public: // Store innermost loops as we walk. - std::vector<ForStmt *> loops; + std::vector<ForInst *> loops; // This method specialized to encode custom return logic. - using InstListType = llvm::iplist<Statement>; + using InstListType = llvm::iplist<Instruction>; bool walkPostOrder(InstListType::iterator Start, InstListType::iterator End) { bool hasInnerLoops = false; @@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { return hasInnerLoops; } - bool walkForStmtPostOrder(ForStmt *forStmt) { + bool walkForInstPostOrder(ForInst *forInst) { bool hasInnerLoops = - walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end()); + walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end()); if (!hasInnerLoops) - loops.push_back(forStmt); + loops.push_back(forInst); return true; } - bool walkIfStmtPostOrder(IfStmt *ifStmt) { + bool walkIfInstPostOrder(IfInst *ifInst) { bool hasInnerLoops = - walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end()); - if (ifStmt->hasElse()) + walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); + if (ifInst->hasElse()) hasInnerLoops |= - walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end()); + walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); return hasInnerLoops; } - bool visitOperationInst(OperationInst *opStmt) { return false; } + bool visitOperationInst(OperationInst *opInst) { return false; } // FIXME: can't use base class method for this because that in turn would // need to use the derived class method above. CRTP doesn't allow it, and // the compiler error resulting from it is also misleading. - using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder; + using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder; }; // Gathers all loops with trip count <= minTripCount. - class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> { + class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> { public: // Store short loops as we walk. - std::vector<ForStmt *> loops; + std::vector<ForInst *> loops; const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitForStmt(ForStmt *forStmt) { - Optional<uint64_t> tripCount = getConstantTripCount(*forStmt); + void visitForInst(ForInst *forInst) { + Optional<uint64_t> tripCount = getConstantTripCount(*forInst); if (tripCount.hasValue() && tripCount.getValue() <= minTripCount) - loops.push_back(forStmt); + loops.push_back(forInst); } }; @@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { // ones). slg.walkPostOrder(f); auto &loops = slg.loops; - for (auto *forStmt : loops) - loopUnrollFull(forStmt); + for (auto *forInst : loops) + loopUnrollFull(forInst); return success(); } @@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { if (loops.empty()) break; bool unrolled = false; - for (auto *forStmt : loops) - unrolled |= runOnForStmt(forStmt); + for (auto *forInst : loops) + unrolled |= runOnForInst(forInst); if (!unrolled) // Break out if nothing was unrolled. break; @@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) { return success(); } -/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false +/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false /// otherwise. The default unroll factor is 4. -bool LoopUnroll::runOnForStmt(ForStmt *forStmt) { +bool LoopUnroll::runOnForInst(ForInst *forInst) { // Use the function callback if one was provided. if (getUnrollFactor) { - return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt)); + return loopUnrollByFactor(forInst, getUnrollFactor(*forInst)); } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) - return loopUnrollByFactor(forStmt, unrollFactor.getValue()); + return loopUnrollByFactor(forInst, unrollFactor.getValue()); // Unroll by the command line factor if one was specified. if (clUnrollFactor.getNumOccurrences() > 0) - return loopUnrollByFactor(forStmt, clUnrollFactor); + return loopUnrollByFactor(forInst, clUnrollFactor); // Unroll completely if full loop unroll was specified. if (clUnrollFull.getNumOccurrences() > 0 || (unrollFull.hasValue() && unrollFull.getValue())) - return loopUnrollFull(forStmt); + return loopUnrollFull(forInst); // Unroll by four otherwise. - return loopUnrollByFactor(forStmt, kDefaultUnrollFactor); + return loopUnrollByFactor(forInst, kDefaultUnrollFactor); } FunctionPass *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, - const std::function<unsigned(const ForStmt &)> &getUnrollFactor) { + const std::function<unsigned(const ForInst &)> &getUnrollFactor) { return new LoopUnroll( unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); |

