summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnrollAndJam.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopUnrollAndJam.cpp')
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp100
1 files changed, 50 insertions, 50 deletions
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 60e8d154f98..f59659cf234 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -40,7 +40,7 @@
// S6(i+1);
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
-// stmt's, bodies of those loops will not be jammed.
+// inst's, bodies of those loops will not be jammed.
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
@@ -49,7 +49,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}
PassResult runOnMLFunction(Function *f) override;
- bool runOnForStmt(ForStmt *forStmt);
+ bool runOnForInst(ForInst *forInst);
static char passID;
};
@@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
- // unroll-and-jammed by this pass. However, runOnForStmt can be called on any
- // for Stmt.
- auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
- if (!forStmt)
+ // unroll-and-jammed by this pass. However, runOnForInst can be called on any
+ // for Inst.
+ auto *forInst = dyn_cast<ForInst>(f->getBody()->begin());
+ if (!forInst)
return success();
- runOnForStmt(forStmt);
+ runOnForInst(forInst);
return success();
}
-/// Unroll and jam a 'for' stmt. Default unroll jam factor is
+/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
-bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
+bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
- return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue());
+ return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
- return loopUnrollJamByFactor(forStmt, clUnrollJamFactor);
+ return loopUnrollJamByFactor(forInst, clUnrollJamFactor);
// Unroll and jam by four otherwise.
- return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
+ return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
}
-bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
- return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
- return loopUnrollJamByFactor(forStmt, unrollJamFactor);
+ return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
+ return loopUnrollJamByFactor(forInst, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
-bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
- // Gathers all maximal sub-blocks of statements that do not themselves include
- // a for stmt (a statement could have a descendant for stmt though in its
- // tree).
- class JamBlockGatherer : public StmtWalker<JamBlockGatherer> {
+bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
+ // Gathers all maximal sub-blocks of instructions that do not themselves
+ // include a for inst (an instruction could have a descendant for inst though
+ // in its tree).
+ class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
- using InstListType = llvm::iplist<Statement>;
+ using InstListType = llvm::iplist<Instruction>;
- // Store iterators to the first and last stmt of each sub-block found.
+ // Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
// This is a linear time walk.
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
- while (it != End && !isa<ForStmt>(it))
+ while (it != End && !isa<ForInst>(it))
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
- // Process all for stmts that appear next.
- while (it != End && isa<ForStmt>(it))
- walkForStmt(cast<ForStmt>(it++));
+ // Process all for insts that appear next.
+ while (it != End && isa<ForInst>(it))
+ walkForInst(cast<ForInst>(it++));
}
}
};
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- if (unrollJamFactor == 1 || forStmt->getBody()->empty())
+ if (unrollJamFactor == 1 || forInst->getBody()->empty())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (!mayBeConstantTripCount.hasValue() &&
- getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
+ getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
return false;
- auto lbMap = forStmt->getLowerBoundMap();
- auto ubMap = forStmt->getUpperBoundMap();
+ auto lbMap = forInst->getLowerBoundMap();
+ auto ubMap = forInst->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
- if (!forStmt->matchingBoundOperandList())
+ if (!forInst->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
@@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
- jbg.walkForStmt(forStmt);
+ jbg.walkForInst(forInst);
auto &subBlocks = jbg.subBlocks;
// Generate the cleanup loop if trip count isn't a multiple of
@@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
- // Insert the cleanup loop right after 'forStmt'.
- FuncBuilder builder(forStmt->getBlock(),
- std::next(Block::iterator(forStmt)));
- auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
- cleanupForStmt->setLowerBoundMap(
- getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
+ // Insert the cleanup loop right after 'forInst'.
+ FuncBuilder builder(forInst->getBlock(),
+ std::next(Block::iterator(forInst)));
+ auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
+ cleanupForInst->setLowerBoundMap(
+ getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
- forStmt->setUpperBoundMap(
- getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
+ forInst->setUpperBoundMap(
+ getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(cleanupForStmt);
+ promoteIfSingleIteration(cleanupForInst);
}
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
- int64_t step = forStmt->getStep();
- forStmt->setStep(step * unrollJamFactor);
+ int64_t step = forInst->getStep();
+ forInst->setStep(step * unrollJamFactor);
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
@@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forStmt->use_empty()) {
+ if (!forInst->use_empty()) {
// iv' = iv + i * step, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
- builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
->getResult(0);
- operandMapping[forStmt] = ivUnroll;
+ operandMapping[forInst] = ivUnroll;
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
}
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(forStmt);
+ promoteIfSingleIteration(forInst);
return true;
}
OpenPOWER on IntegriCloud