summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnrollAndJam.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopUnrollAndJam.cpp')
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp86
1 files changed, 47 insertions, 39 deletions
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 7deaf850362..7327a37ee3a 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -43,6 +43,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -80,7 +81,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}
PassResult runOnFunction(Function *f) override;
- bool runOnForInst(ForInst *forInst);
+ bool runOnAffineForOp(OpPointer<AffineForOp> forOp);
static char passID;
};
@@ -95,47 +96,51 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
- // unroll-and-jammed by this pass. However, runOnForInst can be called on any
- // for Inst.
+ // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
+ // any AffineForOp.
auto &entryBlock = f->front();
if (!entryBlock.empty())
- if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front()))
- runOnForInst(forInst);
+ if (auto forOp =
+ cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>())
+ runOnAffineForOp(forOp);
return success();
}
/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
-bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
+bool LoopUnrollAndJam::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
- return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
+ return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
- return loopUnrollJamByFactor(forInst, clUnrollJamFactor);
+ return loopUnrollJamByFactor(forOp, clUnrollJamFactor);
// Unroll and jam by four otherwise.
- return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
+ return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor);
}
-bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+bool mlir::loopUnrollJamUpToFactor(OpPointer<AffineForOp> forOp,
+ uint64_t unrollJamFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
- return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
- return loopUnrollJamByFactor(forInst, unrollJamFactor);
+ return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
+ return loopUnrollJamByFactor(forOp, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
-bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
+bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
+ uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (a instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
using InstListType = llvm::iplist<Instruction>;
+ using InstWalker<JamBlockGatherer>::walk;
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
@@ -144,30 +149,30 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
- while (it != End && !isa<ForInst>(it))
+ while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
- while (it != End && isa<ForInst>(it))
- walkForInst(cast<ForInst>(it++));
+ while (it != End && cast<OperationInst>(it)->isa<AffineForOp>())
+ walk(&*it++);
}
}
};
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- if (unrollJamFactor == 1 || forInst->getBody()->empty())
+ if (unrollJamFactor == 1 || forOp->getBody()->empty())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (!mayBeConstantTripCount.hasValue() &&
- getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
+ getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0)
return false;
- auto lbMap = forInst->getLowerBoundMap();
- auto ubMap = forInst->getUpperBoundMap();
+ auto lbMap = forOp->getLowerBoundMap();
+ auto ubMap = forOp->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@@ -178,7 +183,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
- if (!forInst->matchingBoundOperandList())
+ if (!forOp->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
@@ -187,35 +192,38 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
mayBeConstantTripCount.getValue() < unrollJamFactor)
return false;
+ auto *forInst = forOp->getInstruction();
+
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
- jbg.walkForInst(forInst);
+ jbg.walkOpInst(forInst);
auto &subBlocks = jbg.subBlocks;
// Generate the cleanup loop if trip count isn't a multiple of
// unrollJamFactor.
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
- // Insert the cleanup loop right after 'forInst'.
+ // Insert the cleanup loop right after 'forOp'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
- auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst));
- cleanupForInst->setLowerBoundMap(
- getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));
+ auto cleanupAffineForOp =
+ cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
+ cleanupAffineForOp->setLowerBoundMap(
+ getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
- forInst->setUpperBoundMap(
- getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));
+ forOp->setUpperBoundMap(
+ getUnrolledLoopUpperBound(forOp, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(cleanupForInst);
+ promoteIfSingleIteration(cleanupAffineForOp);
}
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
- int64_t step = forInst->getStep();
- forInst->setStep(step * unrollJamFactor);
+ int64_t step = forOp->getStep();
+ forOp->setStep(step * unrollJamFactor);
- auto *forInstIV = forInst->getInductionVar();
+ auto *forOpIV = forOp->getInductionVar();
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
// sub-block.
@@ -227,13 +235,13 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forInstIV->use_empty()) {
+ if (!forOpIV->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
- auto ivUnroll = builder.create<AffineApplyOp>(forInst->getLoc(),
- bumpMap, forInstIV);
- operandMapping.map(forInstIV, ivUnroll);
+ auto ivUnroll =
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
+ operandMapping.map(forOpIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@@ -243,7 +251,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
}
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(forInst);
+ promoteIfSingleIteration(forOp);
return true;
}
OpenPOWER on IntegriCloud