# Chapter 6: Lowering to LLVM and Code Generation

[TOC]

In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.

# Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will be performing a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](Glossary.md#transitive-lowering), we don't need to
directly emit operations in the LLVM dialect. By transitive lowering, we mean
that the conversion framework may apply multiple patterns to fully legalize an
operation. In this example, we are generating a structured loop nest instead of
the branch-form in the LLVM dialect. As long as we then have a lowering from the
loop operations to LLVM, the lowering will still succeed.
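
To make transitive lowering concrete, here is a small self-contained model in
plain C++. It does not use the MLIR API: `Op`, `Patterns`, `legalizeOp`, and
`examplePatterns` are hypothetical names, and the pattern table is a simplified
assumption about which operations each stage produces.

```c++
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy model only: an "operation" is just its name, e.g. "toy.print".
using Op = std::string;
// A "pattern" rewrites one op into a sequence of replacement ops.
using Patterns = std::map<Op, std::vector<Op>>;

// Returns the dialect prefix of an op name ("loop.for" -> "loop").
std::string dialectOf(const Op &op) { return op.substr(0, op.find('.')); }

// Legalizes `op`, appending the resulting legal ops to `out`. Replacement ops
// that are still illegal are legalized in turn, which is the "transitive"
// part. Returns false if an illegal op has no pattern.
// Assumption: the pattern table is acyclic, so the recursion terminates.
bool legalizeOp(const Op &op, const Patterns &patterns,
                const std::set<std::string> &legalDialects,
                std::vector<Op> &out) {
  assert(!op.empty() && "operation name must not be empty");
  if (legalDialects.count(dialectOf(op))) {
    out.push_back(op);
    return true;
  }
  auto it = patterns.find(op);
  if (it == patterns.end())
    return false;
  for (const Op &replacement : it->second)
    if (!legalizeOp(replacement, patterns, legalDialects, out))
      return false;
  return true;
}

// Illustrative pattern table: toy.print becomes a structured loop plus a
// call, and the loop is lowered further by separate patterns.
Patterns examplePatterns() {
  return {{"toy.print", {"loop.for", "llvm.call"}},
          {"loop.for", {"std.cond_br", "std.br"}},
          {"std.cond_br", {"llvm.cond_br"}},
          {"std.br", {"llvm.br"}}};
}
```

With `llvm` as the only legal dialect, legalizing `toy.print` in this model
yields `llvm.cond_br`, `llvm.br`, `llvm.call`, even though no single pattern
maps `toy.print` directly to those operations. An operation without a pattern,
such as `toy.mul` in this table, makes the model report failure.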

During lowering we can get, or build, the declaration for `printf` like so:

```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```
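
The get-or-insert idiom used above can be illustrated without MLIR. The
following sketch uses a hypothetical `ToyModule` type (an ordered list of
symbol names) in place of `ModuleOp`; it only shows that repeated requests
declare the symbol once, at the start of the module.

```c++
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for a module: an ordered list of declared symbols.
struct ToyModule {
  std::vector<std::string> symbols;

  // Returns true if `name` is already declared in this module.
  bool contains(const std::string &name) const {
    return std::find(symbols.begin(), symbols.end(), name) != symbols.end();
  }
};

// Returns a reference (here simply the name) to the symbol, declaring it at
// the start of the module the first time it is requested.
std::string getOrInsertSymbol(ToyModule &module, const std::string &name) {
  if (!module.contains(name))
    module.symbols.insert(module.symbols.begin(), name);
  return name;
}
```

Calling `getOrInsertSymbol(module, "printf")` twice on a module that already
holds `main` leaves exactly two symbols, `printf` followed by `main`.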

Now that the lowering for the `toy.print` operation has been defined, we can
specify the remaining components of the conversion. These are largely the same
as the components defined in the [previous chapter](Ch-5.md).

## Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.

```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```

## Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
TypeConverter as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.

```c++
  LLVMTypeConverter typeConverter(&getContext());
```

## Conversion Patterns

Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
`affine` dialects already provide the set of patterns needed to transform them
into LLVM dialect. These patterns allow for lowering the IR in multiple stages
by relying on [transitive lowering](Glossary.md#transitive-lowering).

```c++
  mlir::OwningRewritePatternList patterns;
  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect is the
  // PrintOp.
  patterns.insert<PrintOpLowering>(&getContext());
```

## Full Lowering

We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
that only legal operations will remain after the conversion.

```c++
  mlir::ModuleOp module = getModule();
  if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
                                             &typeConverter)))
    signalPassFailure();
```

Looking back at our current working example:

```mlir
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
  %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
  "toy.print"(%3) : (tensor<3x2xf64>) -> ()
  "toy.return"() : () -> ()
}
```

We can now lower down to the LLVM dialect, which produces the following code:

```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
  %224 = llvm.mul %214, %223 : !llvm.i64
  %225 = llvm.add %222, %224 : !llvm.i64
  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
  %227 = llvm.mul %219, %226 : !llvm.i64
  %228 = llvm.add %225, %227 : !llvm.i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
  %232 = llvm.add %219, %218 : !llvm.i64
  llvm.br ^bb15(%232 : !llvm.i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```

See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
more in-depth details on lowering to the LLVM dialect.
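
The `mul`/`add` chain in `^bb16` computes a linear element index from the two
loop induction variables. The sketch below restates that arithmetic in plain
C++. It assumes the fields of `{ double*, i64, [2 x i64], [2 x i64] }` are, in
order, the data pointer, an offset, the sizes, and the strides; `MemRef2D`,
`linearIndex`, and `load` are hypothetical names.

```c++
#include <cassert>
#include <cstdint>

// Assumed layout of the lowered rank-2 MemRef descriptor:
// data pointer, offset, sizes per dimension, strides per dimension.
struct MemRef2D {
  double *data;
  std::int64_t offset;
  std::int64_t sizes[2];
  std::int64_t strides[2];
};

// Mirrors the lowered arithmetic: offset + i * strides[0] + j * strides[1].
std::int64_t linearIndex(const MemRef2D &m, std::int64_t i, std::int64_t j) {
  assert(i >= 0 && i < m.sizes[0] && j >= 0 && j < m.sizes[1]);
  return m.offset + i * m.strides[0] + j * m.strides[1];
}

// Equivalent of the getelementptr followed by the load.
double load(const MemRef2D &m, std::int64_t i, std::int64_t j) {
  return m.data[linearIndex(m, i, j)];
}
```

For a contiguous 3x2 buffer the strides are `{2, 1}` and the offset is `0`,
which matches the constants `2`, `1`, and `0` that appear in the listing.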

# CodeGen: Getting Out of MLIR

At this point we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT to
run it.

## Emitting LLVM IR

Now that our module consists only of operations in the LLVM dialect, we can
export to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  std::unique_ptr<llvm::Module> llvmModule = mlir::translateModuleToLLVMIR(module);
  if (!llvmModule) {
    /* ... an error was encountered ... */
  }
```

Exporting our module to LLVM IR generates:

```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```

If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:

```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
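
As a cross-check on the constants in the optimized IR, the working example can
be written directly in C++. This is a sketch rather than what the compiler
emits: `transpose`, `mul`, and `formatMatrix` are hypothetical helpers, and the
`"%f "` format string is an assumption inferred from the four-byte `@frmt_spec`
global and the `putchar(10)` call after each row.

```c++
#include <array>
#include <cstddef>
#include <cstdio>
#include <string>

using Matrix2x3 = std::array<std::array<double, 3>, 2>;
using Matrix3x2 = std::array<std::array<double, 2>, 3>;

// toy.transpose: swaps rows and columns.
Matrix3x2 transpose(const Matrix2x3 &a) {
  Matrix3x2 result{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 3; ++j)
      result[j][i] = a[i][j];
  return result;
}

// toy.mul: element-wise multiplication.
Matrix3x2 mul(const Matrix3x2 &a, const Matrix3x2 &b) {
  Matrix3x2 result{};
  for (std::size_t i = 0; i < 3; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      result[i][j] = a[i][j] * b[i][j];
  return result;
}

// toy.print: one "%f " per element and a newline after each row, collected
// into a string instead of being written to stdout.
std::string formatMatrix(const Matrix3x2 &m) {
  std::string out;
  char buf[64];
  for (const auto &row : m) {
    for (double value : row) {
      std::snprintf(buf, sizeof(buf), "%f ", value);
      out += buf;
    }
    out += '\n';
  }
  return out;
}
```

Applying `transpose` to `[[1, 2, 3], [4, 5, 6]]` and multiplying the result by
itself gives the rows `1 16`, `4 25`, and `9 36`, the same six values that are
passed to `printf` in the optimized IR.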

The full code listing for dumping LLVM IR can be found in `Ch6/toy.cpp` in the
`dumpLLVMIR()` function:

```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, that contains the LLVM dialect, to LLVM IR.
  auto llvmModule = mlir::translateModuleToLLVMIR(module);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  // Optionally run an optimization pipeline over the LLVM module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```

## Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up
the JIT can be found in `Ch6/toy.cpp` in the `runJit()` function:

```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function. `invoke` returns an `llvm::Error`,
  // which converts to `true` when the invocation failed.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```

You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```

You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--print-ir-after-all`](../../WritingAPass.md#ir-printing) to track the
evolution of the IR throughout the pipeline.

So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.