summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnroll.cpp
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2018-12-17 09:58:57 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:30:28 -0700
commit4a3e4e8ea7b78db47dee1b46dd37709c0b0ccf02 (patch)
tree1db26aef9bd6639f6691ddaf415a55d9d4cae6be /mlir/lib/Transforms/LoopUnroll.cpp
parent3b69230b3a7e156150d349d139d4b52172585e50 (diff)
downloadbcm5719-llvm-4a3e4e8ea7b78db47dee1b46dd37709c0b0ccf02.tar.gz
bcm5719-llvm-4a3e4e8ea7b78db47dee1b46dd37709c0b0ccf02.zip
loop-unroll - add function callback argument for outside targets to
provide unroll factors, and a cmd line argument to specify number of innermost loop unroll repetitions. - add function callback parameter for outside targets to provide unroll factors - add a cmd line parameter to repeatedly apply innermost loop unroll a certain number of times (to avoid using -loop-unroll -loop-unroll ...; instead -unroll-num-reps=2). - implement the callback for a target - update test cases / usage PiperOrigin-RevId: 225843191
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp70
1 files changed, 50 insertions, 20 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 23dff377f71..a43087bd2e1 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -31,17 +31,22 @@
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
// Loop unrolling factor.
-static llvm::cl::opt<unsigned>
- clUnrollFactor("unroll-factor", llvm::cl::Hidden,
- llvm::cl::desc("Use this unroll factor for all loops"));
+static llvm::cl::opt<unsigned> clUnrollFactor(
+ "unroll-factor", llvm::cl::Hidden,
+ llvm::cl::desc("Use this unroll factor for all loops being unrolled"));
static llvm::cl::opt<bool> clUnrollFull("unroll-full", llvm::cl::Hidden,
llvm::cl::desc("Fully unroll loops"));
+static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
+ "unroll-num-reps", llvm::cl::Hidden,
+ llvm::cl::desc("Unroll innermost loops repeatedly this many times"));
+
static llvm::cl::opt<unsigned> clUnrollFullThreshold(
"unroll-full-threshold", llvm::cl::Hidden,
llvm::cl::desc(
@@ -53,19 +58,25 @@ namespace {
/// with trip count less than the specified threshold. The latter is for testing
/// purposes, especially for testing outer loop unrolling.
struct LoopUnroll : public FunctionPass {
- Optional<unsigned> unrollFactor;
- Optional<bool> unrollFull;
-
- explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
- Optional<bool> unrollFull = None)
+ const Optional<unsigned> unrollFactor;
+ const Optional<bool> unrollFull;
+ // Callback to obtain unroll factors; if this has a callable target, takes
+ // precedence over command-line argument or passed argument.
+ const std::function<unsigned(const ForStmt &)> getUnrollFactor;
+
+ explicit LoopUnroll(
+ Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
+ const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
- unrollFull(unrollFull) {}
+ unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnMLFunction(MLFunction *f) override;
/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
+ static const unsigned kDefaultUnrollFactor = 4;
+
static char passID;
};
} // end anonymous namespace
@@ -144,16 +155,33 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
return success();
}
- InnermostLoopGatherer ilg;
- ilg.walkPostOrder(f);
- auto &loops = ilg.loops;
- for (auto *forStmt : loops)
- runOnForStmt(forStmt);
+ unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0
+ ? clUnrollNumRepetitions
+ : 1;
+ // If the call back is provided, we will recurse until no loops are found.
+ for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
+ InnermostLoopGatherer ilg;
+ ilg.walkPostOrder(f);
+ auto &loops = ilg.loops;
+ if (loops.empty())
+ break;
+ bool unrolled = false;
+ for (auto *forStmt : loops)
+ unrolled |= runOnForStmt(forStmt);
+ if (!unrolled)
+ // Break out if nothing was unrolled.
+ break;
+ }
return success();
}
-/// Unroll a 'for' stmt. Default unroll factor is 4.
+/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false
+/// otherwise. The default unroll factor is 4.
bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+ // Use the function callback if one was provided.
+ if (getUnrollFactor) {
+ return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt));
+ }
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
return loopUnrollByFactor(forStmt, unrollFactor.getValue());
@@ -166,13 +194,15 @@ bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
return loopUnrollFull(forStmt);
// Unroll by four otherwise.
- return loopUnrollByFactor(forStmt, 4);
+ return loopUnrollByFactor(forStmt, kDefaultUnrollFactor);
}
-FunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
- return new LoopUnroll(unrollFactor == -1 ? None
- : Optional<unsigned>(unrollFactor),
- unrollFull == -1 ? None : Optional<bool>(unrollFull));
+FunctionPass *mlir::createLoopUnrollPass(
+ int unrollFactor, int unrollFull,
+ const std::function<unsigned(const ForStmt &)> &getUnrollFactor) {
+ return new LoopUnroll(
+ unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
+ unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
}
static PassRegistration<LoopUnroll> pass("loop-unroll", "Unroll loops");
OpenPOWER on IntegriCloud