summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnroll.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp68
1 files changed, 34 insertions, 34 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 15ea0f841cc..69431bf6349 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -26,7 +26,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, takes
// precedence over command-line argument or passed argument.
- const std::function<unsigned(const ForStmt &)> getUnrollFactor;
+ const std::function<unsigned(const ForInst &)> getUnrollFactor;
explicit LoopUnroll(
Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
- const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr)
+ const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnMLFunction(Function *f) override;
- /// Unroll this for stmt. Returns false if nothing was done.
- bool runOnForStmt(ForStmt *forStmt);
+ /// Unroll this for inst. Returns false if nothing was done.
+ bool runOnForInst(ForInst *forInst);
static const unsigned kDefaultUnrollFactor = 4;
@@ -85,13 +85,13 @@ char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnMLFunction(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
+ class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
- std::vector<ForStmt *> loops;
+ std::vector<ForInst *> loops;
// This method specialized to encode custom return logic.
- using InstListType = llvm::iplist<Statement>;
+ using InstListType = llvm::iplist<Instruction>;
bool walkPostOrder(InstListType::iterator Start,
InstListType::iterator End) {
bool hasInnerLoops = false;
@@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return hasInnerLoops;
}
- bool walkForStmtPostOrder(ForStmt *forStmt) {
+ bool walkForInstPostOrder(ForInst *forInst) {
bool hasInnerLoops =
- walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
+ walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
if (!hasInnerLoops)
- loops.push_back(forStmt);
+ loops.push_back(forInst);
return true;
}
- bool walkIfStmtPostOrder(IfStmt *ifStmt) {
+ bool walkIfInstPostOrder(IfInst *ifInst) {
bool hasInnerLoops =
- walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
- if (ifStmt->hasElse())
+ walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
+ if (ifInst->hasElse())
hasInnerLoops |=
- walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
+ walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
return hasInnerLoops;
}
- bool visitOperationInst(OperationInst *opStmt) { return false; }
+ bool visitOperationInst(OperationInst *opInst) { return false; }
// FIXME: can't use base class method for this because that in turn would
// need to use the derived class method above. CRTP doesn't allow it, and
// the compiler error resulting from it is also misleading.
- using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
+ using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
// Gathers all loops with trip count <= minTripCount.
- class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
+ class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
- std::vector<ForStmt *> loops;
+ std::vector<ForInst *> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- void visitForStmt(ForStmt *forStmt) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
+ void visitForInst(ForInst *forInst) {
+ Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
- loops.push_back(forStmt);
+ loops.push_back(forInst);
}
};
@@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
- for (auto *forStmt : loops)
- loopUnrollFull(forStmt);
+ for (auto *forInst : loops)
+ loopUnrollFull(forInst);
return success();
}
@@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
- for (auto *forStmt : loops)
- unrolled |= runOnForStmt(forStmt);
+ for (auto *forInst : loops)
+ unrolled |= runOnForInst(forInst);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return success();
}
-/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false
+/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
-bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+bool LoopUnroll::runOnForInst(ForInst *forInst) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
- return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt));
+ return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
- return loopUnrollByFactor(forStmt, unrollFactor.getValue());
+ return loopUnrollByFactor(forInst, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
- return loopUnrollByFactor(forStmt, clUnrollFactor);
+ return loopUnrollByFactor(forInst, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
- return loopUnrollFull(forStmt);
+ return loopUnrollFull(forInst);
// Unroll by four otherwise.
- return loopUnrollByFactor(forStmt, kDefaultUnrollFactor);
+ return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
}
FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
- const std::function<unsigned(const ForStmt &)> &getUnrollFactor) {
+ const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
OpenPOWER on IntegriCloud