summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/DialectConversion.cpp
blob: 49fb9373d1fd4cd76c06ea9dea52421b7a52acb5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a generic pass for converting between MLIR dialects.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/Utils.h"

using namespace mlir;

namespace mlir {
namespace impl {
// Implementation detail class of the DialectConversion pass.  Performs
// function-by-function conversions by creating new functions, filling them in
// with converted blocks, updating the function attributes, and replacing the
// old functions with the new ones in the module.
//
// Instances are only created by the static `convert` entry point (the
// constructor is private), one per call, so the pattern set and the
// value/block mapping live only for the duration of that call.
class FunctionConversion {
public:
  // Entry point.  Uses hooks defined in `conversion` to obtain the list of
  // conversion patterns and to convert function and block argument types.
  // Converts the `module` in-place by replacing all existing functions with the
  // converted ones.  Returns failure if `module` is null or if any function
  // could not be converted.
  static LogicalResult convert(DialectConversion *conversion, Module *module);

private:
  // Constructs a FunctionConversion by storing the hooks.
  explicit FunctionConversion(DialectConversion *conversion)
      : dialectConversion(conversion) {}

  // Utility that looks up a list of values in the value remapping table
  // (`mapping`).  Returns an empty vector if one of the values is not mapped
  // yet.  An empty vector is also returned for an empty operand list; callers
  // tell the two cases apart by comparing against the operation's operand
  // count.
  SmallVector<Value *, 4> lookupValues(
      const llvm::iterator_range<Operation::operand_iterator> &operands);

  // Converts the given function to the dialect using hooks defined in
  // `dialectConversion`.  Returns the converted function or `nullptr` on error.
  // The returned function is not inserted into any module; the caller takes
  // ownership of it.
  Function *convertFunction(Function *f);

  // Converts the given region starting from the entry block and following the
  // block successors.  Returns the converted region or `nullptr` on error.
  // `parent` is forwarded to the constructor of the newly created region.
  template <typename RegionParent>
  std::unique_ptr<Region> convertRegion(MLIRContext *context, Region *region,
                                        RegionParent *parent);

  // Converts an operation with successors.  Looks up the converted operands
  // and the converted successor blocks in `mapping`, splits the operands into
  // the regular ones and one list per successor, and passes them to the
  // `converter->rewriteTerminator` function defined in the pattern, together
  // with `builder`.
  LogicalResult convertOpWithSuccessors(DialectOpConversion *converter,
                                        Operation *op, FuncBuilder &builder);

  // Converts an operation without successors.  Looks up the converted operands
  // in `mapping` and passes them to the `converter->rewrite` function defined
  // in the pattern, together with `builder`.  Fails if the rewrite produced a
  // different number of results than `op` has; otherwise records each original
  // result -> replacement pair in `mapping`.
  LogicalResult convertOp(DialectOpConversion *converter, Operation *op,
                          FuncBuilder &builder);

  // Converts a block by traversing its operations sequentially, looking for
  // the first pattern match and dispatching the operation conversion to
  // either `convertOp` or `convertOpWithSuccessors` depending on the presence
  // of successors.  If there is no match, clones the operation (and converts
  // its nested regions).
  //
  // After converting operations, traverses the successor blocks unless they
  // have been visited already as indicated in `visitedBlocks`.
  LogicalResult convertBlock(Block *block, FuncBuilder &builder,
                             llvm::DenseSet<Block *> &visitedBlocks);

  // Converts the module as follows.
  // 1. Call `convertFunction` on each function of the module and collect the
  // mapping between old and new functions.
  // 2. Remap all function attributes in the new functions to point to the new
  // functions instead of the old ones.
  // 3. Replace old functions with the new in the module.
  LogicalResult run(Module *m);

  // Pointer to a specific dialect pass.
  DialectConversion *dialectConversion;

  // Set of known conversion patterns, as returned by `initConverters`.
  // NOTE(review): this is a hashed set, so when several patterns match the
  // same op, which one `convertBlock` applies first depends on hash order.
  llvm::DenseSet<DialectOpConversion *> conversions;

  // Mapping between values (and blocks) in the original function and in the
  // new function.  Shared by all functions and nested regions converted by
  // this instance.
  BlockAndValueMapping mapping;
};
} // end namespace impl
} // end namespace mlir

// Translates each operand through `mapping`.  As soon as one operand has no
// converted counterpart, the partial result is discarded and an empty vector
// is returned instead.
SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
    const llvm::iterator_range<Operation::operand_iterator> &operands) {
  SmallVector<Value *, 4> result;
  result.reserve(llvm::size(operands));
  for (Value *original : operands) {
    if (Value *replacement = mapping.lookupOrNull(original))
      result.push_back(replacement);
    else
      return {};
  }
  return result;
}

// Rewrites a terminator-like operation.  The operand list of `op` holds the
// regular operands first, followed by the operands forwarded to each
// successor in successor order; this slices them accordingly and hands them,
// along with the remapped destination blocks, to the pattern.
LogicalResult impl::FunctionConversion::convertOpWithSuccessors(
    DialectOpConversion *converter, Operation *op, FuncBuilder &builder) {
  SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
  assert((!operands.empty() || op->getNumOperands() == 0) &&
         "converting op before ops defining its operands");

  const unsigned numSuccessors = op->getNumSuccessors();

  // Count the successor operands to locate the end of the regular operands.
  unsigned totalSuccessorOperands = 0;
  for (unsigned succIdx = 0; succIdx != numSuccessors; ++succIdx)
    totalSuccessorOperands += op->getNumSuccessorOperands(succIdx);
  const unsigned numRegularOperands =
      op->getNumOperands() - totalSuccessorOperands;

  // Walk the successor operands with a cursor, cutting one slice per
  // destination.
  SmallVector<Block *, 2> destinations;
  destinations.reserve(numSuccessors);
  SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
  Value **cursor = operands.data() + numRegularOperands;
  for (unsigned succIdx = 0; succIdx != numSuccessors; ++succIdx) {
    Block *mappedSuccessor = mapping.lookupOrNull(op->getSuccessor(succIdx));
    assert(mappedSuccessor && "block was not remapped");
    destinations.push_back(mappedSuccessor);

    const unsigned sliceLength = op->getNumSuccessorOperands(succIdx);
    operandsPerDestination.push_back(llvm::makeArrayRef(cursor, sliceLength));
    cursor += sliceLength;
  }

  converter->rewriteTerminator(
      op, llvm::makeArrayRef(operands.data(), numRegularOperands),
      destinations, operandsPerDestination, builder);
  return success();
}

// Rewrites a successor-free operation through `converter` and records the
// correspondence between the original results and the rewritten values.
// Fails when the pattern returns a different number of values than `op` has
// results.
LogicalResult
impl::FunctionConversion::convertOp(DialectOpConversion *converter,
                                    Operation *op, FuncBuilder &builder) {
  SmallVector<Value *, 4> remappedOperands = lookupValues(op->getOperands());
  assert((!remappedOperands.empty() || op->getNumOperands() == 0) &&
         "converting op before ops defining its operands");

  auto newResults = converter->rewrite(op, remappedOperands, builder);
  const unsigned numResults = op->getNumResults();
  if (newResults.size() != numResults) {
    op->emitError("rewriting produced a different number of results");
    return failure();
  }

  for (unsigned idx = 0; idx != numResults; ++idx)
    mapping.map(op->getResult(idx), newResults[idx]);
  return success();
}

// Converts the operations of `block` into its counterpart in `mapping`, then
// recursively converts not-yet-visited successors (DFS preorder, so a block is
// always converted after every block that dominates it).
//
// Returns failure as soon as any pattern rewrite or nested region conversion
// fails.
LogicalResult
impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
                                       llvm::DenseSet<Block *> &visitedBlocks) {
  // First, add the current block to the list of visited blocks.
  visitedBlocks.insert(block);
  // Set up the builder to insert into the converted counterpart of `block`.
  builder.setInsertionPointToStart(mapping.lookupOrNull(block));

  // Iterate over ops and convert them.
  for (Operation &op : *block) {
    // Find the first matching conversion and apply it.
    // NOTE(review): `conversions` is a hashed set, so when several patterns
    // match, the one applied depends on hash iteration order.
    bool converted = false;
    for (auto *conversion : conversions) {
      if (!conversion->match(&op))
        continue;

      if (op.getNumSuccessors() != 0) {
        if (failed(convertOpWithSuccessors(conversion, &op, builder)))
          return failure();
      } else if (failed(convertOp(conversion, &op, builder))) {
        return failure();
      }
      converted = true;
      break;
    }
    // If there is no conversion provided for the op, clone the op and convert
    // its regions, if any.
    if (!converted) {
      auto *newOp = builder.cloneWithoutRegions(op, mapping);
      for (unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
        auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
        // convertRegion reports its own error and returns nullptr on failure;
        // propagate the failure instead of dereferencing a null region.
        if (!newRegion)
          return failure();
        newOp->getRegion(i).takeBody(*newRegion);
      }
    }
  }

  // Recurse to children unless they have been already visited.
  for (Block *succ : block->getSuccessors()) {
    if (visitedBlocks.count(succ) != 0)
      continue;
    if (failed(convertBlock(succ, builder, visitedBlocks)))
      return failure();
  }
  return success();
}

// Builds a new region (owned by the returned pointer, constructed with
// `parent`) mirroring `region`.  Every block is pre-created with converted
// argument types and registered in `mapping`; the operations are then
// converted in DFS order from the entry block.
//
// Returns nullptr (after emitting an error where this function detects the
// problem) if a block argument type cannot be converted, if any operation
// fails to convert, or if some block is unreachable from the entry block.
template <typename RegionParent>
std::unique_ptr<Region>
impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region,
                                        RegionParent *parent) {
  assert(region && "expected a region");
  auto newRegion = llvm::make_unique<Region>(parent);
  if (region->empty())
    return newRegion;

  // Reports `message` at an unknown location and yields a null region.
  // Twine is taken by const reference as LLVM requires.
  auto emitError =
      [context](const llvm::Twine &message) -> std::unique_ptr<Region> {
    context->emitError(UnknownLoc::get(context), message.str());
    return nullptr;
  };

  // Create new blocks and convert their arguments.  Count the blocks as we go
  // so reachability can be checked below without a signed/unsigned compare.
  size_t numBlocks = 0;
  for (Block &block : *region) {
    ++numBlocks;
    auto *newBlock = new Block;
    newRegion->push_back(newBlock);
    mapping.map(&block, newBlock);
    for (auto *arg : block.getArguments()) {
      auto convertedType = dialectConversion->convertType(arg->getType());
      if (!convertedType)
        return emitError("could not convert block argument type");
      newBlock->addArgument(convertedType);
      mapping.map(arg, *newBlock->args_rbegin());
    }
  }

  // Start a DFS-order traversal of the CFG to make sure defs are converted
  // before uses in dominated blocks.
  llvm::DenseSet<Block *> visitedBlocks;
  FuncBuilder builder(&newRegion->front());
  if (failed(convertBlock(&region->front(), builder, visitedBlocks)))
    return nullptr;

  // If some blocks are not reachable through successor chains, they should have
  // been removed by the DCE before this.
  if (visitedBlocks.size() != numBlocks)
    return emitError("unreachable blocks were not converted");
  return newRegion;
}

// Creates a converted copy of `f`: the signature goes through
// `convertFunctionSignatureType` and the body through `convertRegion`.
// Returns a heap-allocated function that is NOT inserted into any module (the
// caller takes ownership), or nullptr after emitting an error located at `f`.
Function *impl::FunctionConversion::convertFunction(Function *f) {
  assert(f && "expected function");
  MLIRContext *context = f->getContext();
  // Reports `message` at the location of the function being converted and
  // yields a null function.  The parameter is named `message` so it does not
  // shadow `f`.
  auto emitError = [context, f](const llvm::Twine &message) -> Function * {
    context->emitError(f->getLoc(), message.str());
    return nullptr;
  };

  // Create a new function with argument types and result types converted. Wrap
  // it into a unique_ptr to make sure it is cleaned up in case of error.
  SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
  Type newFunctionType = dialectConversion->convertFunctionSignatureType(
      f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs);
  if (!newFunctionType)
    return emitError("could not convert function type");
  auto newFunction = llvm::make_unique<Function>(
      f->getLoc(), f->getName().strref(), newFunctionType.cast<FunctionType>(),
      f->getAttrs(), newFunctionArgAttrs);

  // Return early if the function has no blocks (external declaration).
  if (f->getBlocks().empty())
    return newFunction.release();

  auto newBody = convertRegion(context, &f->getBody(), f);
  if (!newBody)
    return emitError("could not convert function body");
  newFunction->getBody().takeBody(*newBody);

  return newFunction.release();
}

// Static entry point: builds a fresh converter for this call (so no mapping
// state is shared across modules) and runs it on `module`.
LogicalResult impl::FunctionConversion::convert(DialectConversion *conversion,
                                                Module *module) {
  FunctionConversion converter(conversion);
  return converter.run(module);
}

// Converts every function of `module` and, only if all of them succeed,
// replaces the originals with the converted versions.  On failure the module
// is left untouched and the already-converted functions are destroyed.
LogicalResult impl::FunctionConversion::run(Module *module) {
  if (!module)
    return failure();

  MLIRContext *context = module->getContext();
  conversions = dialectConversion->initConverters(context);

  // Snapshot the original functions first; converted functions are not added
  // to the module yet, so they cannot be picked up and converted again.
  SmallVector<Function *, 0> originalFuncs;
  originalFuncs.reserve(module->getFunctions().size());
  for (auto &func : *module)
    originalFuncs.push_back(&func);

  // Converted functions are owned here until they are moved into the module,
  // so that a failure part-way through does not leak the ones converted so far.
  SmallVector<std::unique_ptr<Function>, 0> convertedFuncs;
  convertedFuncs.reserve(originalFuncs.size());
  DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
  for (auto *func : originalFuncs) {
    std::unique_ptr<Function> converted(convertFunction(func));
    if (!converted)
      return failure();

    auto origFuncAttr = FunctionAttr::get(func, context);
    auto convertedFuncAttr = FunctionAttr::get(converted.get(), context);
    functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
    convertedFuncs.push_back(std::move(converted));
  }

  // Remap function attributes in the converted functions (they are not yet in
  // the module).  Original functions will disappear anyway so there is no
  // need to remap attributes in them.  Walk the vector rather than the hash
  // map so the processing order is deterministic.
  for (auto &func : convertedFuncs)
    remapFunctionAttrs(*func, functionAttrRemapping);

  // Remove original functions from the module, then insert converted
  // functions.  The order is important to avoid name collisions.  The module
  // takes ownership of each converted function.
  for (auto *func : originalFuncs)
    func->erase();
  for (auto &func : convertedFuncs)
    module->getFunctions().push_back(func.release());

  return success();
}

// Create a function type with arguments and results converted, and argument
// attributes passed through (appended to `convertedArgAttrs`).
//
// Returns a null FunctionType if any argument or result type cannot be
// converted (i.e. `convertType` returns a null type), matching how block
// argument types are handled; `convertedArgAttrs` is left untouched in that
// case.
FunctionType DialectConversion::convertFunctionSignatureType(
    FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
    SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
  SmallVector<Type, 8> arguments;
  arguments.reserve(type.getNumInputs());
  for (auto t : type.getInputs()) {
    Type converted = convertType(t);
    if (!converted)
      return FunctionType();
    arguments.push_back(converted);
  }

  SmallVector<Type, 4> results;
  results.reserve(type.getNumResults());
  for (auto t : type.getResults()) {
    Type converted = convertType(t);
    if (!converted)
      return FunctionType();
    results.push_back(converted);
  }

  // `append` grows the caller-provided vector by exactly the needed amount,
  // even if it already holds elements.
  convertedArgAttrs.append(argAttrs.begin(), argAttrs.end());

  return FunctionType::get(arguments, results, type.getContext());
}

// Public entry point: all the work is delegated to the function-by-function
// conversion driver, which uses this object's hooks.
LogicalResult DialectConversion::convert(Module *m) {
  LogicalResult status = impl::FunctionConversion::convert(this, m);
  return status;
}
OpenPOWER on IntegriCloud