summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnroll.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopUnroll.cpp')
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp67
1 files changed, 35 insertions, 32 deletions
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 6d63e4afd2d..86e913bd71f 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -21,6 +21,7 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -70,18 +71,19 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, takes
// precedence over command-line argument or passed argument.
- const std::function<unsigned(const ForInst &)> getUnrollFactor;
+ const std::function<unsigned(ConstOpPointer<AffineForOp>)> getUnrollFactor;
- explicit LoopUnroll(
- Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
- const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
+ explicit LoopUnroll(Optional<unsigned> unrollFactor = None,
+ Optional<bool> unrollFull = None,
+ const std::function<unsigned(ConstOpPointer<AffineForOp>)>
+ &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnFunction(Function *f) override;
/// Unroll this for inst. Returns false if nothing was done.
- bool runOnForInst(ForInst *forInst);
+ bool runOnAffineForOp(OpPointer<AffineForOp> forOp);
static const unsigned kDefaultUnrollFactor = 4;
@@ -96,7 +98,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
- std::vector<ForInst *> loops;
+ std::vector<OpPointer<AffineForOp>> loops;
// This method specialized to encode custom return logic.
using InstListType = llvm::iplist<Instruction>;
@@ -111,20 +113,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return hasInnerLoops;
}
- bool walkForInstPostOrder(ForInst *forInst) {
- bool hasInnerLoops =
- walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
- if (!hasInnerLoops)
- loops.push_back(forInst);
- return true;
- }
-
bool walkOpInstPostOrder(OperationInst *opInst) {
+ bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
- if (walkPostOrder(block.begin(), block.end()))
- return true;
- return false;
+ hasInnerLoops |= walkPostOrder(block.begin(), block.end());
+ if (opInst->isa<AffineForOp>()) {
+ if (!hasInnerLoops)
+ loops.push_back(opInst->cast<AffineForOp>());
+ return true;
+ }
+ return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
@@ -137,14 +136,17 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
- std::vector<ForInst *> loops;
+ std::vector<OpPointer<AffineForOp>> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- void visitForInst(ForInst *forInst) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
+ void visitOperationInst(OperationInst *opInst) {
+ auto forOp = opInst->dyn_cast<AffineForOp>();
+ if (!forOp)
+ return;
+ Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
- loops.push_back(forInst);
+ loops.push_back(forOp);
}
};
@@ -156,8 +158,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
- for (auto *forInst : loops)
- loopUnrollFull(forInst);
+ for (auto forOp : loops)
+ loopUnrollFull(forOp);
return success();
}
@@ -172,8 +174,8 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
- for (auto *forInst : loops)
- unrolled |= runOnForInst(forInst);
+ for (auto forOp : loops)
+ unrolled |= runOnAffineForOp(forOp);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@@ -183,29 +185,30 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
-bool LoopUnroll::runOnForInst(ForInst *forInst) {
+bool LoopUnroll::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
- return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
+ return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
- return loopUnrollByFactor(forInst, unrollFactor.getValue());
+ return loopUnrollByFactor(forOp, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
- return loopUnrollByFactor(forInst, clUnrollFactor);
+ return loopUnrollByFactor(forOp, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
- return loopUnrollFull(forInst);
+ return loopUnrollFull(forOp);
// Unroll by four otherwise.
- return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
+ return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
}
FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
- const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
+ const std::function<unsigned(ConstOpPointer<AffineForOp>)>
+ &getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
OpenPOWER on IntegriCloud