diff options
| author | Uday Bondhugula <bondhugula@google.com> | 2018-12-17 09:58:57 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:30:28 -0700 |
| commit | 4a3e4e8ea7b78db47dee1b46dd37709c0b0ccf02 (patch) | |
| tree | 1db26aef9bd6639f6691ddaf415a55d9d4cae6be /mlir/lib/Transforms/LoopUnroll.cpp | |
| parent | 3b69230b3a7e156150d349d139d4b52172585e50 (diff) | |
| download | bcm5719-llvm-4a3e4e8ea7b78db47dee1b46dd37709c0b0ccf02.tar.gz bcm5719-llvm-4a3e4e8ea7b78db47dee1b46dd37709c0b0ccf02.zip | |
loop-unroll - add function callback argument for outside targets to
provide unroll factors, and a cmd line argument to specify number of
innermost loop unroll repetitions.
- add function callback parameter for outside targets to provide unroll factors
- add a cmd line parameter to repeatedly apply innermost loop unroll a certain
number of times (to avoid using -loop-unroll -loop-unroll ...; instead
-unroll-num-reps=2).
- implement the callback for a target
- update test cases / usage
PiperOrigin-RevId: 225843191
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 70 |
1 files changed, 50 insertions, 20 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 23dff377f71..a43087bd2e1 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -31,17 +31,22 @@ #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" using namespace mlir; // Loop unrolling factor. -static llvm::cl::opt<unsigned> - clUnrollFactor("unroll-factor", llvm::cl::Hidden, - llvm::cl::desc("Use this unroll factor for all loops")); +static llvm::cl::opt<unsigned> clUnrollFactor( + "unroll-factor", llvm::cl::Hidden, + llvm::cl::desc("Use this unroll factor for all loops being unrolled")); static llvm::cl::opt<bool> clUnrollFull("unroll-full", llvm::cl::Hidden, llvm::cl::desc("Fully unroll loops")); +static llvm::cl::opt<unsigned> clUnrollNumRepetitions( + "unroll-num-reps", llvm::cl::Hidden, + llvm::cl::desc("Unroll innermost loops repeatedly this many times")); + static llvm::cl::opt<unsigned> clUnrollFullThreshold( "unroll-full-threshold", llvm::cl::Hidden, llvm::cl::desc( @@ -53,19 +58,25 @@ namespace { /// with trip count less than the specified threshold. The latter is for testing /// purposes, especially for testing outer loop unrolling. struct LoopUnroll : public FunctionPass { - Optional<unsigned> unrollFactor; - Optional<bool> unrollFull; - - explicit LoopUnroll(Optional<unsigned> unrollFactor = None, - Optional<bool> unrollFull = None) + const Optional<unsigned> unrollFactor; + const Optional<bool> unrollFull; + // Callback to obtain unroll factors; if this has a callable target, takes + // precedence over command-line argument or passed argument. + const std::function<unsigned(const ForStmt &)> getUnrollFactor; + + explicit LoopUnroll( + Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, + const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr) : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), - unrollFull(unrollFull) {} + unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} PassResult runOnMLFunction(MLFunction *f) override; /// Unroll this for stmt. Returns false if nothing was done. bool runOnForStmt(ForStmt *forStmt); + static const unsigned kDefaultUnrollFactor = 4; + static char passID; }; } // end anonymous namespace @@ -144,16 +155,33 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) { return success(); } - InnermostLoopGatherer ilg; - ilg.walkPostOrder(f); - auto &loops = ilg.loops; - for (auto *forStmt : loops) - runOnForStmt(forStmt); + unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 + ? clUnrollNumRepetitions + : 1; + // If the call back is provided, we will recurse until no loops are found. + for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { + InnermostLoopGatherer ilg; + ilg.walkPostOrder(f); + auto &loops = ilg.loops; + if (loops.empty()) + break; + bool unrolled = false; + for (auto *forStmt : loops) + unrolled |= runOnForStmt(forStmt); + if (!unrolled) + // Break out if nothing was unrolled. + break; + } return success(); } -/// Unroll a 'for' stmt. Default unroll factor is 4. +/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false +/// otherwise. The default unroll factor is 4. bool LoopUnroll::runOnForStmt(ForStmt *forStmt) { + // Use the function callback if one was provided. + if (getUnrollFactor) { + return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt)); + } // Unroll by the factor passed, if any. if (unrollFactor.hasValue()) return loopUnrollByFactor(forStmt, unrollFactor.getValue()); @@ -166,13 +194,15 @@ bool LoopUnroll::runOnForStmt(ForStmt *forStmt) { return loopUnrollFull(forStmt); // Unroll by four otherwise. - return loopUnrollByFactor(forStmt, 4); + return loopUnrollByFactor(forStmt, kDefaultUnrollFactor); } -FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) { - return new LoopUnroll(unrollFactor == -1 ? None - : Optional<unsigned>(unrollFactor), - unrollFull == -1 ? None : Optional<bool>(unrollFull)); +FunctionPass *mlir::createLoopUnrollPass( + int unrollFactor, int unrollFull, + const std::function<unsigned(const ForStmt &)> &getUnrollFactor) { + return new LoopUnroll( + unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), + unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); } static PassRegistration<LoopUnroll> pass("loop-unroll", "Unroll loops"); |

