1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
|
//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a generic pass for converting between MLIR dialects.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/Utils.h"
#include <memory>
using namespace mlir;
namespace mlir {
namespace impl {
// Helper that carries out the work of the DialectConversion pass. The module
// is processed one function at a time: a fresh function is created for each
// original one and populated with converted blocks, function attributes are
// then updated, and finally the original functions are swapped out of the
// module in favour of the newly created ones.
class FunctionConversion {
public:
  // Main entry point. Builds a FunctionConversion around the hooks exposed by
  // `conversion` (list of conversion patterns, conversion of function and
  // block argument types) and runs it over `module`, which is modified
  // in-place: every existing function gets replaced by its converted
  // counterpart. Returns `true` on error.
  static bool convert(DialectConversion *conversion, Module *module);

private:
  // Builds the converter; only stores the pointer to the conversion hooks.
  explicit FunctionConversion(DialectConversion *conversion)
      : dialectConversion(conversion) {}

  // Looks up each value of `operands` in the value remapping table and
  // returns the remapped values in the same order. If any of the values has
  // no mapping yet, an empty vector is returned instead.
  SmallVector<Value *, 4> lookupValues(
      const llvm::iterator_range<Instruction::const_operand_iterator>
          &operands);

  // Converts function `f` to the target dialect through the hooks of
  // `dialectConversion`. Returns the newly created function, or `nullptr` if
  // the conversion failed.
  Function *convertFunction(Function *f);

  // Converts an operation that has successors. The remapped operands and the
  // remapped successor blocks are fetched from `mapping` and handed, along
  // with `builder`, to the `converter->rewriteTerminator` hook of the pattern.
  bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op,
                               FuncBuilder &builder);

  // Converts an operation that has no successors. The remapped operands are
  // fetched from `mapping` and handed, along with `builder`, to the
  // `converter->rewrite` hook of the pattern.
  bool convertOp(DialectOpConversion *converter, Instruction *op,
                 FuncBuilder &builder);

  // Converts `block` by walking its instructions in order. For each
  // instruction the first matching pattern is selected and the conversion is
  // delegated to `convertOpWithSuccessors` or `convertOp`, depending on
  // whether the instruction has successors. Instructions without any matching
  // pattern are cloned.
  //
  // Once all instructions are processed, successor blocks that are not yet
  // recorded in `visitedBlocks` are converted as well.
  //
  // Returns `true` on error.
  bool convertBlock(Block *block, FuncBuilder &builder,
                    llvm::DenseSet<Block *> &visitedBlocks);

  // Converts the whole module in three steps:
  //   1. run `convertFunction` on every function of the module, remembering
  //      which new function corresponds to which old one;
  //   2. rewrite function attributes in the new functions so that they refer
  //      to the new functions rather than to the old ones;
  //   3. substitute the new functions for the old ones in the module.
  bool run(Module *m);

  // The dialect-specific pass providing the conversion hooks.
  DialectConversion *dialectConversion;

  // The conversion patterns known to this converter.
  llvm::DenseSet<DialectOpConversion *> conversions;

  // Correspondence between values (and blocks) of the original function and
  // those of the function being created.
  BlockAndValueMapping mapping;
};
} // end namespace impl
} // end namespace mlir
// Translates every value in `operands` through the value remapping table,
// preserving order. The lookup is all-or-nothing: as soon as one operand has
// no remapped counterpart, an empty vector is returned. Note that an empty
// vector is also the result for an empty operand range.
SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
    const llvm::iterator_range<Instruction::const_operand_iterator> &operands) {
  SmallVector<Value *, 4> result;
  result.reserve(llvm::size(operands));
  for (const Value *original : operands) {
    if (Value *mapped = mapping.lookupOrNull(original)) {
      result.push_back(mapped);
      continue;
    }
    // One missing mapping invalidates the entire lookup.
    return {};
  }
  return result;
}
// Converts `op`, an operation with successors, using the pattern `converter`.
//
// The remapped operand list is split into the operands that precede the
// successor operands and one group of operands per successor; successor
// operands are taken to be the trailing operands of `op`, laid out in
// successor order. Both parts, together with the remapped successor blocks,
// are forwarded to `converter->rewriteTerminator`.
//
// Always returns `false`.
bool impl::FunctionConversion::convertOpWithSuccessors(
    DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
  SmallVector<Value *, 4> remappedOperands = lookupValues(op->getOperands());
  assert((op->getNumOperands() == 0 || !remappedOperands.empty()) &&
         "converting op before ops defining its operands");

  const unsigned numSuccessors = op->getNumSuccessors();

  // Count the operands forwarded to successors to find out where they start
  // in the operand list.
  unsigned totalSuccessorOperands = 0;
  for (unsigned idx = 0; idx != numSuccessors; ++idx)
    totalSuccessorOperands += op->getNumSuccessorOperands(idx);
  const unsigned successorOperandsBegin =
      op->getNumOperands() - totalSuccessorOperands;

  // Collect the remapped successor blocks and carve the per-successor operand
  // groups out of the remapped operand list.
  SmallVector<Block *, 2> newSuccessors;
  newSuccessors.reserve(numSuccessors);
  SmallVector<ArrayRef<Value *>, 2> successorOperandGroups;
  Value **cursor = remappedOperands.data() + successorOperandsBegin;
  for (unsigned idx = 0; idx != numSuccessors; ++idx) {
    Block *target = mapping.lookupOrNull(op->getSuccessor(idx));
    assert(target && "block was not remapped");
    newSuccessors.push_back(target);

    const unsigned groupSize = op->getNumSuccessorOperands(idx);
    successorOperandGroups.push_back(llvm::makeArrayRef(cursor, groupSize));
    cursor += groupSize;
  }

  // Everything before the first successor operand belongs to the operation
  // itself.
  Value **operandsBegin = remappedOperands.data();
  converter->rewriteTerminator(
      op,
      llvm::makeArrayRef(operandsBegin, operandsBegin + successorOperandsBegin),
      newSuccessors, successorOperandGroups, builder);
  return false;
}
// Converts `op`, an operation without successors, using the pattern
// `converter`. The remapped operands are passed to `converter->rewrite`, and
// each result of `op` is then mapped to the value produced for it.
//
// Emits an error on `op` and returns the value of that `emitError` call if
// the rewrite yields a different number of values than `op` has results;
// returns `false` otherwise.
bool impl::FunctionConversion::convertOp(DialectOpConversion *converter,
                                         Instruction *op,
                                         FuncBuilder &builder) {
  auto remappedOperands = lookupValues(op->getOperands());
  assert((op->getNumOperands() == 0 || !remappedOperands.empty()) &&
         "converting op before ops defining its operands");

  auto newResults = converter->rewrite(op, remappedOperands, builder);
  if (op->getNumResults() != newResults.size())
    return op->emitError("rewriting produced a different number of results");

  // Record the correspondence between old and new results so that later
  // users of `op` can find their remapped operands.
  unsigned resultIdx = 0;
  const unsigned numResults = newResults.size();
  while (resultIdx < numResults) {
    mapping.map(op->getResult(resultIdx), newResults[resultIdx]);
    ++resultIdx;
  }
  return false;
}
// Converts the instructions of `block` into its remapped counterpart and then
// recursively converts the successors of `block` that are not in
// `visitedBlocks` yet.
//
// Each instruction is handled by the first pattern in `conversions` that
// matches it; instructions matched by no pattern are cloned through
// `builder`. Instructions with block lists are rejected with an error.
//
// Returns `true` on error.
bool impl::FunctionConversion::convertBlock(
    Block *block, FuncBuilder &builder,
    llvm::DenseSet<Block *> &visitedBlocks) {
  // Mark the block as visited up front so that it is not entered again
  // through one of its successors.
  visitedBlocks.insert(block);

  // New instructions go to the beginning of the converted block.
  builder.setInsertionPointToStart(mapping.lookupOrNull(block));

  for (Instruction &inst : *block) {
    if (inst.getNumBlockLists() != 0) {
      inst.emitError("unsupported region instruction");
      return true;
    }

    // Pick the first pattern that matches the instruction, if any.
    DialectOpConversion *pattern = nullptr;
    for (auto *candidate : conversions) {
      if (candidate->match(&inst)) {
        pattern = candidate;
        break;
      }
    }

    // Without a matching pattern the instruction is cloned as is.
    if (!pattern) {
      builder.clone(inst, mapping);
      continue;
    }

    // Terminators with successors need their blocks remapped as well.
    const bool hasSuccessors =
        inst.isTerminator() && inst.getNumSuccessors() > 0;
    const bool failed = hasSuccessors
                            ? convertOpWithSuccessors(pattern, &inst, builder)
                            : convertOp(pattern, &inst, builder);
    if (failed)
      return true;
  }

  // Descend into the successors that have not been visited so far.
  for (Block *successor : block->getSuccessors()) {
    if (visitedBlocks.count(successor) == 0 &&
        convertBlock(successor, builder, visitedBlocks))
      return true;
  }
  return false;
}
// Creates the converted counterpart of function `f`.
//
// The type of the function and the types of all block arguments are converted
// through `dialectConversion->convertType`. The blocks of `f` are recreated in
// the new function and recorded in `mapping`, and the instructions are then
// converted starting from the first block of `f`.
//
// Returns the new function, which the caller is responsible for, or `nullptr`
// on error.
Function *impl::FunctionConversion::convertFunction(Function *f) {
  assert(f && "expected function");
  MLIRContext *context = f->getContext();

  // Reports `message` at an unknown location and yields the error result.
  auto fail = [context](llvm::Twine message) -> Function * {
    context->emitError(UnknownLoc::get(context), message.str());
    return nullptr;
  };

  // The new function has the converted signature. It is held by a unique_ptr
  // so that it gets destroyed on every error path below.
  Type convertedFuncType = dialectConversion->convertType(f->getType());
  if (!convertedFuncType)
    return fail("could not convert function type");
  auto converted = llvm::make_unique<Function>(
      f->getLoc(), f->getName().strref(), convertedFuncType.cast<FunctionType>(),
      f->getAttrs());

  // A function without blocks needs nothing more than the new signature.
  if (f->getBlocks().empty())
    return converted.release();

  // Mirror every block of the original function, giving the new blocks
  // arguments of the converted types.
  FuncBuilder builder(converted.get());
  for (Block &oldBlock : *f) {
    Block *createdBlock = builder.createBlock();
    mapping.map(&oldBlock, createdBlock);
    for (auto *oldArg : oldBlock.getArguments()) {
      auto newArgType = dialectConversion->convertType(oldArg->getType());
      if (!newArgType)
        return fail("could not convert block argument type");
      createdBlock->addArgument(newArgType);
      // The argument that was just added is the last one of the new block.
      mapping.map(oldArg, *createdBlock->args_rbegin());
    }
  }

  // Walk the CFG in DFS order from the first block so that definitions get
  // converted before their uses in dominated blocks.
  llvm::DenseSet<Block *> visited;
  const bool failed = convertBlock(&f->front(), builder, visited);
  if (failed)
    return nullptr;

  // Blocks that cannot be reached by following successors are expected to
  // have been removed by DCE beforehand.
  if (f->getBlocks().size() != visited.size())
    return fail("unreachable blocks were not converted");

  return converted.release();
}
// Static entry point: sets up a converter around the hooks of `conversion`
// and runs it over `module`. Returns `true` on error.
bool impl::FunctionConversion::convert(DialectConversion *conversion,
                                       Module *module) {
  impl::FunctionConversion converter(conversion);
  return converter.run(module);
}
// Converts all functions of `module` in-place.
//
// The conversion patterns are obtained from
// `dialectConversion->initConverters`. Every function of the module is then
// converted with `convertFunction`, function attributes in the converted
// functions are remapped to refer to the converted functions, and the
// original functions are finally replaced by the converted ones.
//
// Returns `true` on error, that is if `module` is null or if any function
// could not be converted. In the latter case the module is left untouched and
// the functions that had already been converted are destroyed.
bool impl::FunctionConversion::run(Module *module) {
  if (!module)
    return true;

  MLIRContext *context = module->getContext();
  conversions = dialectConversion->initConverters(context);

  // Convert the functions but don't add them to the module yet to avoid
  // converted functions to be converted again. Until they are handed over to
  // the module, the converted functions are owned by `convertedFuncs` so that
  // bailing out on a failed conversion does not leak the functions converted
  // before the failure.
  SmallVector<Function *, 0> originalFuncs;
  SmallVector<std::unique_ptr<Function>, 0> convertedFuncs;
  DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
  originalFuncs.reserve(module->getFunctions().size());
  for (auto &func : *module)
    originalFuncs.push_back(&func);
  convertedFuncs.reserve(originalFuncs.size());
  for (auto *func : originalFuncs) {
    Function *converted = convertFunction(func);
    if (!converted)
      return true;
    // Take ownership right away, before anything else can bail out.
    convertedFuncs.emplace_back(converted);
    auto origFuncAttr = FunctionAttr::get(func, context);
    auto convertedFuncAttr = FunctionAttr::get(converted, context);
    functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
  }

  // Remap function attributes in the converted functions (they are not yet in
  // the module). Original functions will disappear anyway so there is no
  // need to remap attributes in them.
  for (const auto &funcPair : functionAttrRemapping) {
    remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
  }

  // Remove original functions from the module, then insert converted
  // functions. The order is important to avoid name collisions.
  for (auto *func : originalFuncs)
    func->erase();
  // From here on the module is responsible for the converted functions.
  for (auto &func : convertedFuncs)
    module->getFunctions().push_back(func.release());
  return false;
}
// Pass hook: converts module `m` using this object as the source of the
// conversion hooks, and translates the error flag returned by
// `FunctionConversion::convert` into a pass result.
PassResult DialectConversion::runOnModule(Module *m) {
  if (impl::FunctionConversion::convert(this, m))
    return failure();
  return success();
}
|