summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopUnrollAndJam.cpp
blob: b8e98379051b58b16f3c911ab1881f266c8d5bb6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
//===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements loop unroll and jam for MLFunctions. Unroll and jam is a
// transformation that improves locality, in particular, register reuse, while
// also improving instruction level parallelism. The example below shows what it
// does in nearly the general case. Loop unroll and jam currently works if the
// bounds of the loops inner to the loop being unroll-jammed do not depend on
// the latter.
//
// Before      After unroll and jam of i by factor 2:
//
//             for i, step = 2
// for i         S1(i);
//   S1;         S2(i);
//   S2;         S1(i+1);
//   for j       S2(i+1);
//     S3;       for j
//     S4;         S3(i, j);
//   S5;           S4(i, j);
//   S6;           S3(i+1, j)
//                 S4(i+1, j)
//               S5(i);
//               S6(i);
//               S5(i+1);
//               S6(i+1);
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
// stmt's, bodies of those loops will not be jammed.
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"

using namespace mlir;

// Loop unroll and jam factor.
static llvm::cl::opt<unsigned>
    clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden,
                      llvm::cl::desc("Use this unroll jam factor for all loops"
                                     " (default 4)"));

namespace {
/// Loop unroll jam pass. Currently, this just unroll jams the first
/// outer loop in an MLFunction.
struct LoopUnrollAndJam : public MLFunctionPass {
  Optional<unsigned> unrollJamFactor;
  static const unsigned kDefaultUnrollJamFactor = 4;

  explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor)
      : unrollJamFactor(unrollJamFactor) {}

  PassResult runOnMLFunction(MLFunction *f) override;
  bool runOnForStmt(ForStmt *forStmt);
};
} // end anonymous namespace

MLFunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
  return new LoopUnrollAndJam(
      unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
}

PassResult LoopUnrollAndJam::runOnMLFunction(MLFunction *f) {
  // Currently, just the outermost loop from the first loop nest is
  // unroll-and-jammed by this pass. However, runOnForStmt can be called on any
  // for Stmt.
  if (!isa<ForStmt>(f->begin()))
    return success();

  auto *forStmt = cast<ForStmt>(f->begin());
  runOnForStmt(forStmt);
  return success();
}

/// Unroll and jam a 'for' stmt. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
  // Unroll and jam by the factor that was passed if any.
  if (unrollJamFactor.hasValue())
    return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue());
  // Otherwise, unroll jam by the command-line factor if one was specified.
  if (clUnrollJamFactor.getNumOccurrences() > 0)
    return loopUnrollJamByFactor(forStmt, clUnrollJamFactor);

  // Unroll and jam by four otherwise.
  return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
}

bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);

  if (mayBeConstantTripCount.hasValue() &&
      mayBeConstantTripCount.getValue() < unrollJamFactor)
    return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
  return loopUnrollJamByFactor(forStmt, unrollJamFactor);
}

/// Unrolls and jams this loop by the specified factor.
bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
  // Gathers all maximal sub-blocks of statements that do not themselves include
  // a for stmt (a statement could have a descendant for stmt though in its
  // tree).
  class JamBlockGatherer : public StmtWalker<JamBlockGatherer> {
  public:
    typedef llvm::iplist<Statement> StmtListType;

    // Store iterators to the first and last stmt of each sub-block found.
    std::vector<std::pair<StmtBlock::iterator, StmtBlock::iterator>> subBlocks;

    // This is a linear time walk.
    void walk(StmtListType::iterator Start, StmtListType::iterator End) {
      for (auto it = Start; it != End;) {
        auto subBlockStart = it;
        while (it != End && !isa<ForStmt>(it))
          ++it;
        if (it != subBlockStart)
          subBlocks.push_back({subBlockStart, std::prev(it)});
        // Process all for stmts that appear next.
        while (it != End && isa<ForStmt>(it))
          walkForStmt(cast<ForStmt>(it++));
      }
    }
  };

  assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");

  if (unrollJamFactor == 1 || forStmt->getStatements().empty())
    return false;

  Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);

  if (!mayBeConstantTripCount.hasValue() &&
      getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
    return false;

  auto lbMap = forStmt->getLowerBoundMap();
  auto ubMap = forStmt->getUpperBoundMap();

  // Loops with max/min expressions won't be unrolled here (the output can't be
  // expressed as an MLFunction in the general case). However, the right way to
  // do such unrolling for an MLFunction would be to specialize the loop for the
  // 'hotspot' case and unroll that hotspot.
  if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
    return false;

  // Same operand list for lower and upper bound for now.
  // TODO(bondhugula): handle bounds with different sets of operands.
  if (!forStmt->matchingBoundOperandList())
    return false;

  // If the trip count is lower than the unroll jam factor, no unroll jam.
  // TODO(bondhugula): option to specify cleanup loop unrolling.
  if (mayBeConstantTripCount.hasValue() &&
      mayBeConstantTripCount.getValue() < unrollJamFactor)
    return false;

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer jbg;
  jbg.walkForStmt(forStmt);
  auto &subBlocks = jbg.subBlocks;

  // Generate the cleanup loop if trip count isn't a multiple of
  // unrollJamFactor.
  if (mayBeConstantTripCount.hasValue() &&
      mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
    DenseMap<const MLValue *, MLValue *> operandMap;
    // Insert the cleanup loop right after 'forStmt'.
    MLFuncBuilder builder(forStmt->getBlock(),
                          std::next(StmtBlock::iterator(forStmt)));
    auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
    cleanupForStmt->setLowerBoundMap(
        getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));

    // The upper bound needs to be adjusted.
    forStmt->setUpperBoundMap(
        getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));

    // Promote the loop body up if this has turned into a single iteration loop.
    promoteIfSingleIteration(cleanupForStmt);
  }

  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
  int64_t step = forStmt->getStep();
  forStmt->setStep(step * unrollJamFactor);

  for (auto &subBlock : subBlocks) {
    // Builder to insert unroll-jammed bodies. Insert right at the end of
    // sub-block.
    MLFuncBuilder builder(subBlock.first->getBlock(),
                          std::next(subBlock.second));

    // Unroll and jam (appends unrollJamFactor-1 additional copies).
    for (unsigned i = 1; i < unrollJamFactor; i++) {
      DenseMap<const MLValue *, MLValue *> operandMapping;

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forStmt->use_empty()) {
        // iv' = iv + i, i = 1 to unrollJamFactor-1.
        auto d0 = builder.getAffineDimExpr(0);
        auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
        auto *ivUnroll =
            builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
                ->getResult(0);
        operandMapping[forStmt] = cast<MLValue>(ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
        builder.clone(*it, operandMapping);
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  promoteIfSingleIteration(forStmt);

  return true;
}
OpenPOWER on IntegriCloud