Diffstat (limited to 'mlir/docs/Tutorials/Toy')
-rw-r--r-- | mlir/docs/Tutorials/Toy/Ch-1.md | 169
-rwxr-xr-x | mlir/docs/Tutorials/Toy/Ch-2.md | 577
-rw-r--r-- | mlir/docs/Tutorials/Toy/Ch-3.md | 264
-rw-r--r-- | mlir/docs/Tutorials/Toy/Ch-4.md | 387
-rw-r--r-- | mlir/docs/Tutorials/Toy/Ch-5.md | 357
-rw-r--r-- | mlir/docs/Tutorials/Toy/Ch-6.md | 323
-rw-r--r-- | mlir/docs/Tutorials/Toy/Ch-7.md | 539
7 files changed, 2616 insertions, 0 deletions
diff --git a/mlir/docs/Tutorials/Toy/Ch-1.md b/mlir/docs/Tutorials/Toy/Ch-1.md
new file mode 100644
index 00000000000..cb7f97cb3f6
--- /dev/null
+++ b/mlir/docs/Tutorials/Toy/Ch-1.md
@@ -0,0 +1,169 @@

# Chapter 1: Toy Tutorial Introduction

[TOC]

This tutorial runs through the implementation of a basic toy language on top of
MLIR. The goal of this tutorial is to introduce the concepts of MLIR; in
particular, how [dialects](../../LangRef.md#dialects) can help easily support
language-specific constructs and transformations while still offering an easy
path to lower to LLVM or other codegen infrastructure. This tutorial is based on
the model of the
[LLVM Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html).

This tutorial assumes you have cloned and built MLIR; if you have not yet done
so, see
[Getting started with MLIR](https://github.com/tensorflow/mlir#getting-started-with-mlir).

## The Chapters

This tutorial is divided into the following chapters:

- [Chapter #1](Ch-1.md): Introduction to the Toy language and the definition
  of its AST.
- [Chapter #2](Ch-2.md): Traversing the AST to emit a dialect in MLIR,
  introducing base MLIR concepts. Here we show how to start attaching
  semantics to our custom operations in MLIR.
- [Chapter #3](Ch-3.md): High-level language-specific optimization using the
  pattern rewriting system.
- [Chapter #4](Ch-4.md): Writing generic dialect-independent transformations
  with Interfaces. Here we will show how to plug dialect-specific information
  into generic transformations like shape inference and inlining.
- [Chapter #5](Ch-5.md): Partially lowering to lower-level dialects. We'll
  convert some of our high-level language-specific semantics towards a
  generic affine-oriented dialect for optimization.
- [Chapter #6](Ch-6.md): Lowering to LLVM and code generation. Here we'll
  target LLVM IR for code generation, and detail more of the lowering
  framework.
- [Chapter #7](Ch-7.md): Extending Toy: Adding support for a composite type.
  We'll demonstrate how to add a custom type to MLIR, and how it fits in the
  existing pipeline.

## The Language

This tutorial will be illustrated with a toy language that we’ll call “Toy”
(naming is hard...). Toy is a tensor-based language that allows you to define
functions, perform some math computation, and print results.

Given that we want to keep things simple, the codegen will be limited to tensors
of rank <= 2, and the only datatype in Toy is a 64-bit floating point type (aka
‘double’ in C parlance). As such, all values are implicitly double precision,
`Values` are immutable (i.e. every operation returns a newly allocated value),
and deallocation is automatically managed. But enough with the long description;
nothing is better than walking through an example to get a better understanding:

```Toy {.toy}
def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];

  # b is identical to a, the literal tensor is implicitly reshaped: defining new
  # variables is the way to reshape tensors (element count must match).
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # transpose() and print() are the only builtin functions; the following will
  # transpose a and b and perform an element-wise multiplication before
  # printing the result.
  print(transpose(a) * transpose(b));
}
```

Type checking is statically performed through type inference; the language only
requires type declarations to specify tensor shapes when needed. Functions are
generic: their parameters are unranked (in other words, we know these are
tensors, but we don't know their dimensions). They are specialized for every
newly discovered signature at call sites. Let's revisit the previous example by
adding a user-defined function:

```Toy {.toy}
# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  var a = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # This call will specialize `multiply_transpose` with <2, 3> for both
  # arguments and deduce a return type of <3, 2> in initialization of `c`.
  var c = multiply_transpose(a, b);

  # A second call to `multiply_transpose` with <2, 3> for both arguments will
  # reuse the previously specialized and inferred version and return <3, 2>.
  var d = multiply_transpose(b, a);

  # A new call with <3, 2> (instead of <2, 3>) for both dimensions will
  # trigger another specialization of `multiply_transpose`.
  var e = multiply_transpose(c, d);

  # Finally, calling into `multiply_transpose` with incompatible shapes will
  # trigger a shape inference error.
  var f = multiply_transpose(transpose(a), c);
}
```

## The AST

The AST from the above code is fairly straightforward; here is a dump of it:

```
Module:
  Function
    Proto 'multiply_transpose' @test/ast.toy:5:1'
    Args: [a, b]
    Block {
      Return
        BinOp: * @test/ast.toy:6:25
          Call 'transpose' [ @test/ast.toy:6:10
            var: a @test/ast.toy:6:20
          ]
          Call 'transpose' [ @test/ast.toy:6:25
            var: b @test/ast.toy:6:35
          ]
    } // Block
  Function
    Proto 'main' @test/ast.toy:9:1'
    Args: []
    Block {
      VarDecl a<> @test/ast.toy:11:3
        Literal: <2, 3>[<3>[1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[4.000000e+00, 5.000000e+00, 6.000000e+00]] @test/ast.toy:11:17
      VarDecl b<2, 3> @test/ast.toy:12:3
        Literal: <6>[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @test/ast.toy:12:17
      VarDecl c<> @test/ast.toy:15:3
        Call 'multiply_transpose' [ @test/ast.toy:15:11
          var: a @test/ast.toy:15:30
          var: b @test/ast.toy:15:33
        ]
      VarDecl d<> @test/ast.toy:18:3
        Call 'multiply_transpose' [ @test/ast.toy:18:11
          var: b @test/ast.toy:18:30
          var: a @test/ast.toy:18:33
        ]
      VarDecl e<> @test/ast.toy:21:3
        Call 'multiply_transpose' [ @test/ast.toy:21:11
          var: b @test/ast.toy:21:30
          var: c @test/ast.toy:21:33
        ]
      VarDecl f<> @test/ast.toy:24:3
        Call 'multiply_transpose' [ @test/ast.toy:24:11
          Call 'transpose' [ @test/ast.toy:24:30
            var: a @test/ast.toy:24:40
          ]
          var: c @test/ast.toy:24:44
        ]
    } // Block
```

You can reproduce this result and play with the example in the
`examples/toy/Ch1/` directory; try running `path/to/BUILD/bin/toyc-ch1
test/Examples/Toy/Ch1/ast.toy -emit=ast`.

The code for the lexer is fairly straightforward; it is all in a single header:
`examples/toy/Ch1/include/toy/Lexer.h`. The parser can be found in
`examples/toy/Ch1/include/toy/Parser.h`; it is a recursive descent parser.
If you are not familiar with such a lexer/parser, these are very similar to the
LLVM Kaleidoscope equivalents detailed in the first two chapters of the
[Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl02.html).

The [next chapter](Ch-2.md) will demonstrate how to convert this AST into MLIR.

diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md
new file mode 100755
index 00000000000..ce46788f4ae
--- /dev/null
+++ b/mlir/docs/Tutorials/Toy/Ch-2.md
@@ -0,0 +1,577 @@

# Chapter 2: Emitting Basic MLIR

[TOC]

Now that we're familiar with our language and the AST, let's see how MLIR can
help to compile Toy.

## Introduction: Multi-Level Intermediate Representation

Other compilers, like LLVM (see the
[Kaleidoscope tutorial](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html)),
offer a fixed set of predefined types and (usually *low-level* / RISC-like)
instructions. It is up to the frontend for a given language to perform any
language-specific type-checking, analysis, or transformation before emitting
LLVM IR. For example, Clang will use its AST to perform not only static analysis
but also transformations, such as C++ template instantiation through AST cloning
and rewrite. Finally, languages with constructs at a higher level than C/C++
may require non-trivial lowering from their AST to generate LLVM IR.

As a consequence, multiple frontends end up reimplementing significant pieces of
infrastructure to support the need for these analyses and transformations. MLIR
addresses this issue by being designed for extensibility. As such, there are few
pre-defined instructions (*operations* in MLIR terminology) or types.

## Interfacing with MLIR

[Language reference](../../LangRef.md)

MLIR is designed to be a completely extensible infrastructure; there is no
closed set of attributes (think: constant metadata), operations, or types. MLIR
supports this extensibility with the concept of
[Dialects](../../LangRef.md#dialects). Dialects provide a grouping mechanism for
abstraction under a unique `namespace`.

In MLIR, [`Operations`](../../LangRef.md#operations) are the core unit of
abstraction and computation, similar in many ways to LLVM instructions.
Operations can have application-specific semantics and can be used to represent
all of the core IR structures in LLVM: instructions, globals (like functions),
modules, etc.

Here is the MLIR assembly for the Toy `transpose` operation:

```mlir
%t_tensor = "toy.transpose"(%tensor) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64> loc("example/file/path":12:1)
```

Let's break down the anatomy of this MLIR operation:

- `%t_tensor`

  * The name given to the result defined by this operation (which includes
    [a prefixed sigil to avoid collisions](../../LangRef.md#identifiers-and-keywords)).
    An operation may define zero or more results (in the context of Toy, we
    will limit ourselves to single-result operations), which are SSA values.
    The name is used during parsing but is not persistent (e.g., it is not
    tracked in the in-memory representation of the SSA value).

- `"toy.transpose"`

  * The name of the operation. It is expected to be a unique string, with
    the namespace of the dialect prefixed before the "`.`". This can be read
    as the `transpose` operation in the `toy` dialect.
- `(%tensor)`

  * A list of zero or more input operands (or arguments), which are SSA
    values defined by other operations or referring to block arguments.

- `{ inplace = true }`

  * A dictionary of zero or more attributes, which are special operands that
    are always constant. Here we define a boolean attribute named 'inplace'
    that has a constant value of true.

- `(tensor<2x3xf64>) -> tensor<3x2xf64>`

  * This refers to the type of the operation in a functional form, spelling
    the types of the arguments in parentheses and the type of the return
    values afterward.

- `loc("example/file/path":12:1)`

  * This is the location in the source code from which this operation
    originated.

Shown here is the general form of an operation. As described above, the set of
operations in MLIR is extensible. This means that the infrastructure must be
able to opaquely reason about the structure of an operation. This is done by
boiling down the composition of an operation into discrete pieces:

- A name for the operation.
- A list of SSA operand values.
- A list of [attributes](../../LangRef.md#attributes).
- A list of [types](../../LangRef.md#type-system) for result values.
- A [source location](../../Diagnostics.md#source-locations) for debugging
  purposes.
- A list of successor [blocks](../../LangRef.md#blocks) (mostly for
  branches).
- A list of [regions](../../LangRef.md#regions) (for structural operations
  like functions).

In MLIR, every operation has a mandatory source location associated with it.
Contrary to LLVM, where debug info locations are metadata and can be dropped, in
MLIR, the location is a core requirement, and APIs depend on and manipulate it.
Dropping a location is thus an explicit choice which cannot happen by mistake.

To provide an illustration: if a transformation replaces an operation with
another, that new operation must still have a location attached. This makes it
possible to track where that operation came from.

It's worth noting that the `mlir-opt` tool, a tool for testing compiler passes,
does not include locations in the output by default. The
`-mlir-print-debuginfo` flag enables printing locations. (Run `mlir-opt
--help` for more options.)

### Opaque API

MLIR is designed to be a completely extensible system, and as such, the
infrastructure has the capability to opaquely represent all of its core
components: attributes, operations, types, etc. This allows MLIR to parse,
represent, and [round-trip](../../Glossary.md#round-trip) any valid IR. For
example, we could place our Toy operation from above into an `.mlir` file and
round-trip through *mlir-opt* without registering anything:

```mlir
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
  %t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
  return %t_tensor : tensor<3x2xf64>
}
```

In the cases of unregistered attributes, operations, and types, MLIR will
enforce some structural constraints (SSA, block termination, etc.), but
otherwise they are completely opaque. This can be useful for bootstrapping
purposes, but it is generally advised against. Opaque operations must be treated
conservatively by transformations and analyses, and they are much harder to
construct and manipulate.
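To make "opaque" concrete, here is a minimal sketch (the helper name is
hypothetical, not part of the tutorial code) that inspects an operation purely
through the generic `Operation` API, without knowing anything about its dialect:

```c++
#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

/// Hypothetical helper: even an unregistered operation can be queried through
/// the generic Operation API for its name, operands, attributes, and results.
void inspectOpaqueOp(mlir::Operation *op) {
  llvm::outs() << "op '" << op->getName().getStringRef() << "' has "
               << op->getNumOperands() << " operand(s), "
               << op->getAttrs().size() << " attribute(s), and "
               << op->getNumResults() << " result(s)\n";
}
```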
This handling can be observed by crafting what should be an invalid IR for Toy
and seeing it round-trip without tripping the verifier:

```mlir
// RUN: toyc %s -emit=mlir

func @main() {
  %0 = "toy.print"() : () -> tensor<2x3xf64>
}
```

There are multiple problems here: the `toy.print` operation is not a terminator;
it should take an operand; and it shouldn't return any values. In the next
section, we will register our dialect and operations with MLIR, plug into the
verifier, and add nicer APIs to manipulate our operations.

## Defining a Toy Dialect

To effectively interface with MLIR, we will define a new Toy dialect. This
dialect will properly model the semantics of the Toy language, as well as
provide an easy avenue for high-level analysis and transformation.

```c++
/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and registers custom attributes, operations, and types (in its
/// constructor). It can also override some general behavior exposed via virtual
/// methods, which will be demonstrated in later chapters of the tutorial.
class ToyDialect : public mlir::Dialect {
 public:
  explicit ToyDialect(mlir::MLIRContext *ctx);

  /// Provide a utility accessor to the dialect namespace. This is used by
  /// several utilities.
  static llvm::StringRef getDialectNamespace() { return "toy"; }
};
```

The dialect can now be registered in the global registry:

```c++
  mlir::registerDialect<ToyDialect>();
```

Any new `MLIRContext` created from now on will contain an instance of the Toy
dialect and invoke specific hooks for things like parsing attributes and types.

## Defining Toy Operations

Now that we have a `Toy` dialect, we can start registering operations. This will
allow for providing semantic information that the rest of the system can hook
into. Let's walk through the creation of the `toy.constant` operation:

```mlir
 %4 = "toy.constant"() {value = dense<1.0> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
```

This operation takes zero operands, a
[dense elements](../../LangRef.md#dense-elements-attribute) attribute named
`value`, and returns a single result of
[TensorType](../../LangRef.md#tensor-type). An operation inherits from the
[CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
`mlir::Op` class which also takes some optional [*traits*](../../Traits.md) to
customize its behavior. These traits may provide additional accessors,
verification, etc.

```c++
class ConstantOp : public mlir::Op<ConstantOp,
                     /// The ConstantOp takes zero inputs.
                     mlir::OpTrait::ZeroOperands,
                     /// The ConstantOp returns a single result.
                     mlir::OpTrait::OneResult,
                     /// The ConstantOp is pure and has no visible side-effects.
                     mlir::OpTrait::HasNoSideEffect> {

 public:
  /// Inherit the constructors from the base Op class.
  using Op::Op;

  /// Provide the unique name for this operation. MLIR will use this to register
  /// the operation and uniquely identify it throughout the system.
  static llvm::StringRef getOperationName() { return "toy.constant"; }

  /// Return the value of the constant by fetching it from the attribute.
  mlir::DenseElementsAttr getValue();

  /// Operations can provide additional verification beyond the traits they
  /// define. Here we will ensure that the specific invariants of the constant
  /// operation are upheld, for example the result type must be of TensorType.
  LogicalResult verify();

  /// Provide an interface to build this operation from a set of input values.
  /// This interface is used by the builder to allow for easily generating
  /// instances of this operation:
  ///   mlir::OpBuilder::create<ConstantOp>(...)
  /// This method populates the given `state` that MLIR uses to create
  /// operations. This state is a collection of all of the discrete elements
  /// that an operation may contain.
  /// Build a constant with the given return type and `value` attribute.
  static void build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::Type result, mlir::DenseElementsAttr value);
  /// Build a constant and reuse the type from the given 'value'.
  static void build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::DenseElementsAttr value);
  /// Build a constant by broadcasting the given 'value'.
  static void build(mlir::Builder *builder, mlir::OperationState &state,
                    double value);
};
```

and we register this operation in the `ToyDialect` constructor:

```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addOperations<ConstantOp>();
}
```

### Op vs Operation: Using MLIR Operations

Now that we have defined an operation, we will want to access and transform it.
In MLIR, there are two main classes related to operations: `Operation` and `Op`.
`Operation` is the actual opaque instance of the operation, and represents the
general API into an operation instance. An `Op` is the base class of a derived
operation, like `ConstantOp`, and acts as a smart pointer wrapper around an
`Operation*`. This means that when we define our Toy operations, we are actually
providing a clean interface for building and interfacing with the `Operation`
class; this is why our `ConstantOp` defines no class fields. Therefore, we
always pass these classes around by value, instead of by reference or pointer
(*passing by value* is a common idiom and applies similarly to attributes,
types, etc). We can always get an instance of our toy operation by using LLVM's
casting infrastructure:

```c++
void processConstantOp(mlir::Operation *operation) {
  ConstantOp op = llvm::dyn_cast<ConstantOp>(operation);

  // This operation is not an instance of `ConstantOp`.
  if (!op)
    return;

  // Get the internal operation instance back.
  mlir::Operation *internalOperation = op.getOperation();
  assert(internalOperation == operation &&
         "these operation instances are the same");
}
```

### Using the Operation Definition Specification (ODS) Framework

In addition to specializing the `mlir::Op` C++ template, MLIR also supports
defining operations in a declarative manner. This is achieved via the
[Operation Definition Specification](../../OpDefinitions.md) framework. Facts
regarding an operation are specified concisely into a TableGen record, which
will be expanded into an equivalent `mlir::Op` C++ template specialization at
compile time. Using the ODS framework is the desired way for defining operations
in MLIR given its simplicity, conciseness, and general stability in the face of
C++ API changes.

Let's see how to define the ODS equivalent of our ConstantOp:

The first thing to do is to define a link to the Toy dialect that we defined in
C++. This is used to link all of the operations that we will define to our
dialect:

```tablegen
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
  // The namespace of our dialect, this corresponds 1-1 with the string we
  // provided in `ToyDialect::getDialectNamespace`.
  let name = "toy";

  // The C++ namespace that the dialect class definition resides in.
  let cppNamespace = "toy";
}
```

Now that we have defined a link to the Toy dialect, we can start defining
operations. Operations in ODS are defined by inheriting from the `Op` class. To
simplify our operation definitions, we will define a base class for operations
in the Toy dialect.

```tablegen
// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;
```

With all of the preliminary pieces defined, we can begin to define the constant
operation.

We define a toy operation by inheriting from our base 'Toy_Op' class above. Here
we provide the mnemonic and a list of traits for the operation. The
[mnemonic](../../OpDefinitions.md#operation-name) here matches the one given in
`ConstantOp::getOperationName` without the dialect prefix, `toy.`. The constant
operation here is also marked as 'NoSideEffect'. This is an ODS trait, and
matches one-to-one with the trait we provided when defining `ConstantOp`:
`mlir::OpTrait::HasNoSideEffect`. Missing here from our C++ definition are the
`ZeroOperands` and `OneResult` traits; these will be automatically inferred
based upon the `arguments` and `results` fields we define later.

```tablegen
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
}
```

At this point, you might want to know what the C++ code generated by
TableGen looks like. Simply run the `mlir-tblgen` command with the
`gen-op-decls` or the `gen-op-defs` action like so:

```
${build_root}/bin/mlir-tblgen -gen-op-defs ${mlir_src_root}/examples/toy/Ch2/include/toy/Ops.td -I ${mlir_src_root}/include/
```

Depending on the selected action, this will print either the `ConstantOp` class
declaration or its implementation. Comparing this output to the hand-crafted
implementation is incredibly useful when getting started with TableGen.

#### Defining Arguments and Results

With the shell of the operation defined, we can now provide the
[inputs](../../OpDefinitions.md#operation-arguments) and
[outputs](../../OpDefinitions.md#operation-results) to our operation. The
inputs, or arguments, to an operation may be attributes or types for SSA operand
values. The results correspond to a set of types for the values produced by the
operation:

```tablegen
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // The constant operation takes an attribute as the only input.
  // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  // F64Tensor corresponds to a 64-bit floating-point TensorType.
  let results = (outs F64Tensor);
}
```

By providing a name to the arguments or results, e.g. `$value`, ODS will
automatically generate a matching accessor: `DenseElementsAttr
ConstantOp::value()`.

#### Adding Documentation

The next step after defining the operation is to document it.
Operations may provide
[`summary` and `description`](../../OpDefinitions.md#operation-documentation)
fields to describe the semantics of the operation. This information is useful
for users of the dialect and can even be used to auto-generate Markdown
documents.

```tablegen
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // Provide a summary and description for this operation. This can be used to
  // auto-generate documentation of the operations within our dialect.
  let summary = "constant operation";
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

      %0 = "toy.constant"()
         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
        : () -> tensor<2x3xf64>
  }];

  // The constant operation takes an attribute as the only input.
  // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  // F64Tensor corresponds to a 64-bit floating-point TensorType.
  let results = (outs F64Tensor);
}
```

#### Verifying Operation Semantics

At this point we've already covered a majority of the original C++ operation
definition. The next piece to define is the verifier. Luckily, much like the
named accessor, the ODS framework will automatically generate a lot of the
necessary verification logic based upon the constraints we have given. This
means that we don't need to verify the structure of the return type, or even the
input attribute `value`. In many cases, additional verification is not even
necessary for ODS operations. To add additional verification logic, an operation
can override the [`verifier`](../../OpDefinitions.md#custom-verifier-code)
field. The `verifier` field allows for defining a C++ code blob that will be run
as part of `ConstantOp::verify`. This blob can assume that all of the other
invariants of the operation have already been verified:

```tablegen
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // Provide a summary and description for this operation. This can be used to
  // auto-generate documentation of the operations within our dialect.
  let summary = "constant operation";
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

      %0 = "toy.constant"()
         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
        : () -> tensor<2x3xf64>
  }];

  // The constant operation takes an attribute as the only input.
  // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  // F64Tensor corresponds to a 64-bit floating-point TensorType.
  let results = (outs F64Tensor);

  // Add additional verification logic to the constant operation. Here we invoke
  // a static `verify` method in a C++ source file. This codeblock is executed
  // inside of ConstantOp::verify, so we can use `this` to refer to the current
  // operation instance.
  let verifier = [{ return ::verify(*this); }];
}
```

#### Attaching `build` Methods

The final missing component here from our original C++ example are the `build`
methods. ODS can generate some simple build methods automatically, and in this
case it will generate our first build method for us.
For the rest, we define the
[`builders`](../../OpDefinitions.md#custom-builder-methods) field. This field
takes a list of `OpBuilder` objects that take a string corresponding to a list
of C++ parameters, as well as an optional code block that can be used to specify
the implementation inline.

```tablegen
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // Provide a summary and description for this operation. This can be used to
  // auto-generate documentation of the operations within our dialect.
  let summary = "constant operation";
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

      %0 = "toy.constant"()
         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
        : () -> tensor<2x3xf64>
  }];

  // The constant operation takes an attribute as the only input.
  // `F64ElementsAttr` corresponds to a 64-bit floating-point ElementsAttr.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  // F64Tensor corresponds to a 64-bit floating-point TensorType.
  let results = (outs F64Tensor);

  // Add additional verification logic to the constant operation. Here we invoke
  // a static `verify` method in a C++ source file. This codeblock is executed
  // inside of ConstantOp::verify, so we can use `this` to refer to the current
  // operation instance.
  let verifier = [{ return ::verify(*this); }];

  // Add custom build methods for the constant operation. These methods populate
  // the `state` that MLIR uses to create operations, i.e. these are used when
  // using `builder.create<ConstantOp>(...)`.
  let builders = [
    // Build a constant with a given constant tensor value.
    OpBuilder<"Builder *builder, OperationState &result, "
              "DenseElementsAttr value", [{
      // Call into an autogenerated `build` method.
      build(builder, result, value.getType(), value);
    }]>,

    // Build a constant with a given constant floating-point value. This builder
    // creates a declaration for `ConstantOp::build` with the given parameters.
    OpBuilder<"Builder *builder, OperationState &result, double value">
  ];
}
```

Above we introduce several of the concepts for defining operations in the ODS
framework, but there are many more that we haven't had a chance to cover:
regions, variadic operands, etc. Check out the
[full specification](../../OpDefinitions.md) for more details.

## Complete Toy Example

At this point we can generate our "Toy IR". A simplified version of the previous
example:

```.toy
# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  var c = multiply_transpose(a, b);
  var d = multiply_transpose(b, a);
  print(d);
}
```

Results in the following IR:

```mlir
module {
  func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
    %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:10)
    %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:25)
    %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:25)
    "toy.return"(%2) : (tensor<*xf64>) -> () loc("test/codegen.toy":5:3)
  } loc("test/codegen.toy":4:1)
  func @main() {
    %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> loc("test/codegen.toy":9:17)
    %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> loc("test/codegen.toy":9:3)
    %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> loc("test/codegen.toy":10:17)
    %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64> loc("test/codegen.toy":10:3)
    %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/codegen.toy":11:11)
    %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/codegen.toy":12:11)
    "toy.print"(%5) : (tensor<*xf64>) -> () loc("test/codegen.toy":13:3)
    "toy.return"() : () -> () loc("test/codegen.toy":8:1)
  } loc("test/codegen.toy":8:1)
} loc("test/codegen.toy":0:0)
```

You can build `toyc-ch2` and try it yourself: `toyc-ch2
test/Examples/Toy/Ch2/codegen.toy -emit=mlir -mlir-print-debuginfo`. We can also
check our round-trip: `toyc-ch2 test/Examples/Toy/Ch2/codegen.toy -emit=mlir
-mlir-print-debuginfo 2> codegen.mlir` followed by `toyc-ch2 codegen.mlir
-emit=mlir`. You should also use `mlir-tblgen` on the final definition file and
study the generated C++ code.

At this point, MLIR knows about our Toy dialect and operations. In the
[next chapter](Ch-3.md), we will leverage our new dialect to implement some
high-level language-specific analyses and transformations for the Toy language.

diff --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md
new file mode 100644
index 00000000000..615c2c1bbec
--- /dev/null
+++ b/mlir/docs/Tutorials/Toy/Ch-3.md
@@ -0,0 +1,264 @@

# Chapter 3: High-level Language-Specific Analysis and Transformation

[TOC]

Creating a dialect that closely represents the semantics of an input language
enables analyses, transformations and optimizations in MLIR that require
high-level language information and are generally performed on the language AST.
For example, `clang` has a fairly
[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html)
for performing template instantiation in C++.

We divide compiler transformations into two categories: local and global. In
this chapter, we focus on how to leverage the Toy Dialect and its high-level
semantics to perform local pattern-match transformations that would be difficult
in LLVM. For this, we use MLIR's
[Generic DAG Rewriter](../../GenericDAGRewriter.md).
There are two methods that can be used to implement pattern-match
transformations:

1. Imperative, C++ pattern-match and rewrite.
2. Declarative, rule-based pattern-match and rewrite using table-driven
   [Declarative Rewrite Rules](../../DeclarativeRewrites.md) (DRR). Note that
   the use of DRR requires that the operations be defined using ODS, as
   described in [Chapter 2](Ch-2.md).

# Optimize Transpose using C++ style pattern-match and rewrite

Let's start with a simple pattern and try to eliminate a sequence of two
transposes that cancel out: `transpose(transpose(X)) -> X`. Here is the
corresponding Toy example:

```Toy(.toy)
def transpose_transpose(x) {
  return transpose(transpose(x));
}
```

Which corresponds to the following IR:

```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
  %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
  "toy.return"(%1) : (tensor<*xf64>) -> ()
}
```

This is a good example of a transformation that is trivial to match on the Toy
IR but that would be quite hard for LLVM to figure out. For example, today Clang
can't optimize away the temporary array, and the computation with the naive
transpose is expressed with these loops:

```c++
#define N 100
#define M 100

void sink(void *);
void double_transpose(int A[N][M]) {
  int B[M][N];
  for(int i = 0; i < N; ++i) {
    for(int j = 0; j < M; ++j) {
       B[j][i] = A[i][j];
    }
  }
  for(int i = 0; i < N; ++i) {
    for(int j = 0; j < M; ++j) {
       A[i][j] = B[j][i];
    }
  }
  sink(A);
}
```

For a simple C++ approach to rewrite involving matching a tree-like pattern in
the IR and replacing it with a different set of operations, we can plug into the
MLIR `Canonicalizer` pass by implementing a `RewritePattern`:

```c++
/// Fold transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process
  /// them in order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// This method is attempting to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. It is expected
  /// to interact with it to perform any changes to the IR from here.
  mlir::PatternMatchResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current transpose.
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp =
        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
    // If the input is not defined by another transpose, there is nothing to do.
    if (!transposeInputOp)
      return matchFailure();

    // Use the rewriter to perform the replacement.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
    return matchSuccess();
  }
};
```

The implementation of this rewriter is in `ToyCombine.cpp`. The
[canonicalization pass](../../Canonicalization.md) applies transformations
defined by operations in a greedy, iterative manner. To ensure that the
canonicalization pass applies our new transform, we set
[hasCanonicalizer = 1](../../OpDefinitions.md#hascanonicalizer) and register the
pattern with the canonicalization framework.
```c++
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyRedundantTranspose>(context);
}
```

We also need to update our main file, `toyc.cpp`, to add an optimization
pipeline. In MLIR, the optimizations are run through a `PassManager` in a
similar way to LLVM:

```c++
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
```

Finally, we can run `toyc-ch3 test/transpose_transpose.toy -emit=mlir -opt` and
observe our pattern in action:

```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
  "toy.return"(%arg0) : (tensor<*xf64>) -> ()
}
```

As expected, we now directly return the function argument, bypassing any
transpose operation. However, one of the transposes still hasn't been
eliminated. That is not ideal! What happened is that our pattern replaced the
last transpose with the function input and left behind the now dead transpose
input. The Canonicalizer knows to clean up dead operations; however, MLIR
conservatively assumes that operations may have side-effects. We can fix this by
adding a new trait, `NoSideEffect`, to our `TransposeOp`:

```tablegen
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
```

Now let's retry `toyc-ch3 test/transpose_transpose.toy -emit=mlir -opt`:

```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  "toy.return"(%arg0) : (tensor<*xf64>) -> ()
}
```

Perfect! No `transpose` operation is left - the code is optimal.

In the next section, we use DRR for pattern-match optimizations associated with
the Reshape op.

# Optimize Reshapes using DRR

Declarative, rule-based pattern-match and rewrite (DRR) is an operation
DAG-based declarative rewriter that provides a table-based syntax for
pattern-match and rewrite rules:

```tablegen
class Pattern<
    dag sourcePattern, list<dag> resultPatterns,
    list<dag> additionalConstraints = [],
    dag benefitsAdded = (addBenefit 0)>;
```

A redundant reshape optimization similar to SimplifyRedundantTranspose can be
expressed more simply using DRR as follows:

```tablegen
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;
```

The automatically generated C++ code corresponding to each of the DRR patterns
can be found under path/to/BUILD/projects/mlir/examples/toy/Ch3/ToyCombine.inc.

DRR also provides a method for adding argument constraints when the
transformation is conditional on some properties of the arguments and results.
An example is a transformation that eliminates reshapes when they are redundant,
i.e. when the input and output shapes are identical.

```tablegen
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreIdentical $res, $arg)]>;
```

Some optimizations may require additional transformations on instruction
arguments. This is achieved using NativeCodeCall, which allows for more complex
transformations either by calling into a C++ helper function or by using inline
C++.
An example of such an optimization is FoldConstantReshape, where we
optimize Reshape of a constant value by reshaping the constant in place and
eliminating the reshape operation.

```tablegen
def ReshapeConstant : NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
  (ReshapeOp:$res (ConstantOp $arg)),
  (ConstantOp (ReshapeConstant $arg, $res))>;
```

We demonstrate these reshape optimizations using the following
`trivialReshape.toy` program:

```Toy(.toy)
def main() {
  var a<2,1> = [1, 2];
  var b<2,1> = a;
  var c<2,1> = b;
  print(c);
}
```

```mlir
module {
  func @main() {
    %0 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}
                        : () -> tensor<2xf64>
    %1 = "toy.reshape"(%0) : (tensor<2xf64>) -> tensor<2x1xf64>
    %2 = "toy.reshape"(%1) : (tensor<2x1xf64>) -> tensor<2x1xf64>
    %3 = "toy.reshape"(%2) : (tensor<2x1xf64>) -> tensor<2x1xf64>
    "toy.print"(%3) : (tensor<2x1xf64>) -> ()
    "toy.return"() : () -> ()
  }
}
```

We can try to run `toyc-ch3 test/trivialReshape.toy -emit=mlir -opt` and observe
our pattern in action:

```mlir
module {
  func @main() {
    %0 = "toy.constant"() {value = dense<[[1.000000e+00], [2.000000e+00]]> \
      : tensor<2x1xf64>} : () -> tensor<2x1xf64>
    "toy.print"(%0) : (tensor<2x1xf64>) -> ()
    "toy.return"() : () -> ()
  }
}
```

As expected, no reshape operations remain after canonicalization.

Further details on the declarative rewrite method can be found at
[Table-driven Declarative Rewrite Rule (DRR)](../../DeclarativeRewrites.md).

In this chapter, we saw how to use certain core transformations through
always-available hooks. In the [next chapter](Ch-4.md), we will see how to use
generic solutions that scale better through Interfaces.

diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
new file mode 100644
index 00000000000..4a4e11c68e6
--- /dev/null
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -0,0 +1,387 @@

# Chapter 4: Enabling Generic Transformation with Interfaces

[TOC]

## Background: Grappling with an Extensible IR

Through dialects, MLIR allows for the representation of many different levels of
abstraction; the Toy dialect that we have previously defined is one such
example. Though these different dialects may represent different abstractions,
there is often a set of common transformations and analyses that we would like
to perform. The problem that arises is that naively implementing each
transformation for each dialect leads to large amounts of code duplication, as
the internal algorithms are generally very similar, if not the same. We would
like to provide the ability for transformations to opaquely hook into dialects
like Toy to get the information they need.

MLIR provides a set of always-available hooks for certain core transformations,
as seen in the [previous chapter](Ch-3.md), where we registered some
canonicalizations via a hook on our operations (`getCanonicalizationPatterns`).
However, these types of hooks don't really scale well. Therefore, a more generic
solution was designed, in the form of [interfaces](../../Interfaces.md), to make
the MLIR infrastructure as extensible as the representation. Interfaces provide
a generic mechanism for dialects and operations to provide information to a
transformation or analysis.
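As a rough sketch of the pattern (using the `ShapeInference` operation
interface that this chapter defines later), a transformation written against an
interface never names a concrete dialect or operation:

```c++
/// A minimal sketch: the pass is written purely against the interface, so any
/// dialect whose operations implement `ShapeInference` gets the transformation
/// for free. The helper name here is hypothetical.
void inferShapeIfPossible(mlir::Operation *op) {
  // dyn_cast returns a "null" interface instance when `op` does not implement
  // ShapeInference, so unknown operations are skipped gracefully.
  if (ShapeInference shapeOp = llvm::dyn_cast<ShapeInference>(op))
    shapeOp.inferShapes();
}
```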
## Shape Inference: Preparing for Code Generation

Our Toy IR currently operates on generic tensors, meaning that we don't know the
shape of tensors other than during the initialization of constants. This
complicates optimizations, as well as code generation. Fortunately, we can
simply propagate the shapes through the computation until they are all known.
The issue is how to handle calls to user-defined generic functions: every call
site could deduce different shapes. One possibility would be to perform symbolic
inference based on the argument types, but this would be hard to generalize if
we were to introduce more control flow in the language. Another approach would
be function specialization, where every call site with new argument shapes
duplicates the called function and specializes it. The approach we take for Toy
is to inline all of the function calls, then perform intraprocedural shape
propagation.

### Inlining

Here we could write an inlining algorithm specifically designed for the Toy
dialect, but that can become quite complicated depending on the level of
complexity that we want. Disregarding cost modeling, the pure structural
transformation is already complex to implement from scratch. Thankfully, MLIR
provides a generic inliner algorithm that dialects can plug into. All we need to
do in Toy is to provide the [interfaces](../../Interfaces.md) for the inliner to
hook into.

The first thing we need to do is to define the constraints on inlining
operations in the Toy dialect. This information is provided through a
[dialect interface](../../Interfaces.md#dialect-interfaces). This is essentially
a class containing a set of virtual hooks for which a dialect may provide a
specialization. In this case, the interface is `DialectInlinerInterface`.

```c++
/// This class defines the interface for handling inlining with Toy operations.
/// We simply inherit from the base interface class and provide a
/// specialization of the necessary methods.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// This hook checks to see if the given operation is legal to inline into the
  /// given region. For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation (toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
  }
};
```

We then register our dialect interface directly on the Toy dialect, similarly to
how we did for operations.

```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addInterfaces<ToyInlinerInterface>();
}
```

Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call to a function.
MLIR provides an
[operation interface](../../Interfaces.md#operation-interfaces) that can be used
to mark an operation as being "call-like". Unlike dialect interfaces, operation
interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interface that we will be adding here is the
`CallOpInterface`.

To add this interface we just need to include the definition into our operation
specification file (`Ops.td`):

```tablegen
#ifdef MLIR_CALLINTERFACES
#else
include "mlir/Analysis/CallInterfaces.td"
#endif // MLIR_CALLINTERFACES
```

and add it to the traits list of `GenericCallOp`:

```tablegen
def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  ...
}
```

In the above we also use the `DeclareOpInterfaceMethods` directive to
auto-declare all of the interface methods in the class declaration of
GenericCallOp. This means that we just need to provide a definition:

```c++
/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return getAttrOfType<SymbolRefAttr>("callee");
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
```

Now that the inliner has been informed about the Toy dialect, we can add the
inliner pass to the pass manager for Toy:

```c++
  pm.addPass(mlir::createInlinerPass());
```

Now let's look at a working example:

```mlir
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
  %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
  %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
  %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  "toy.return"(%2) : (tensor<*xf64>) -> ()
}
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
  %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
  %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
  %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  "toy.print"(%5) : (tensor<*xf64>) -> ()
  "toy.return"() : () -> ()
}
```

We have two calls to `multiply_transpose` that we would like to inline into
`main`, but if we look at the output nothing has changed. We are missing one
last subtle piece: there is a hidden type conversion on the edge of the call. If
we look at the above, the operands to the generic_call are of type
`tensor<2x3xf64>`, while the inputs to the function expect `tensor<*xf64>`. To
resolve this difference, the inliner expects an explicit cast operation to be
inserted. For this, we need to add a new operation to the Toy dialect,
`ToyCastOp` (toy.cast), to represent casts between two different shapes.
```tablegen
def CastOp : Toy_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types
    must both be tensor types with the same element type. If both are ranked
    then the rank should be the same and static dimensions should match. The
    operation is invalid if converting to a mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);

  // Set the folder bit so that we can fold redundant cast operations.
  let hasFolder = 1;
}
```

We can then override the necessary hook on the ToyInlinerInterface to insert
this for us when necessary:

```c++
struct ToyInlinerInterface : public DialectInlinerInterface {
  ...

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
```

If we run the working example through the pipeline again, we get the expected:

```mlir
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %1 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
  %3 = "toy.cast"(%0) : (tensor<2x3xf64>) -> tensor<*xf64>
  %4 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
  %5 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
  %6 = "toy.mul"(%4, %5) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  "toy.print"(%6) : (tensor<*xf64>) -> ()
  "toy.return"() : () -> ()
}
```

NOTE: The generic inliner will also perform simplifications, so the output may
be a bit cleaner than expected.

### Intraprocedural Shape Inference

Now that we have inlined all of the functions, we are left with a main function
containing a mix of static and dynamically shaped operations. We can now write a
simple shape inference pass to propagate shapes intraprocedurally (within a
single function). We could write this as a pass that directly encodes the
constraints of the operations within the Toy dialect, but this seems like a good
candidate for a transformation that could be written generically. As a good rule
of thumb, it is best to express a transformation as generically as possible,
such that it can be extended to other dialects in the future. There is no
telling how many other dialects may have similar needs or encounter the same
problems.

For shape inference, if we break down the problem to its core, we really just
want operations to tell us the expected outputs given a set of statically known
inputs. (We can definitely get more complex than that, but for our needs we can
keep it simple.)
Given that this property is core to a specific operation, we
can define an operation interface that can be specified on operations that need
to have their result shapes inferred.

Similarly to operations, we can also
[define operation interfaces](../../OpDefinitions.md#operation-interfaces) using
the operation definition specification (ODS) framework.

The interface is defined by inheriting from `OpInterface`, which takes the name
to be given to the generated C++ interface class as a template argument. For our
purposes, we will give the generated class the simpler name `ShapeInference`. We
also provide a description for the interface.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];
}
```

Next, we define the interface methods that the operations will need to provide.
An interface method comprises: a description; a C++ return type in string
form; a method name in string form; and a few optional components, depending on
the need. See the
[ODS documentation](../../OpDefinitions.md#operation-interfaces) for more
information.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
```

Now that the interface is defined, we can add it to the necessary Toy operations
in a similar way to how we added the `CallOpInterface` to the GenericCallOp:

```
def MulOp : Toy_Op<"mul",
    [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  ...
}
```

Each of these operations will then need to provide a definition for the
`inferShapes()` method. As an example, for the mul op, the result shape is
inferred as the shape of the inputs.

```c++
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
```

At this point, each of the necessary Toy operations provide a mechanism by which
to infer their output shapes. The `ShapeInferencePass` is a `FunctionPass`: it
will run on each function in isolation. MLIR also supports general
[OperationPasses](../../WritingAPass.md#operation-pass) that run on any isolated
operation (i.e. other function-like operations), but here our module only
contains functions, so there is no need to generalize to all operations.

Implementing such a pass is done by creating a class inheriting from
`mlir::FunctionPass` and overriding the `runOnFunction()` method:

```c++
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
  void runOnFunction() override {
    FuncOp function = getFunction();
    ...
  }
};
```

The algorithm operates as follows:

1. Build a worklist containing all the operations that return a dynamically
   shaped tensor: these are the operations that need shape inference.
2. Iterate on the worklist:
   - find an operation to process: the next ready operation in the worklist
     has all of its arguments non-generic,
   - if no operation is found, break out of the loop,
   - remove the operation from the worklist,
   - infer the shape of its output from the argument types.
If the worklist is empty, the algorithm succeeded.
+
+When processing an operation, we query whether it registered the
+`ShapeInference` interface.
+
+```c++
+  // Ask the operation to infer its output shapes.
+  LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+
+  // We check if an operation has a particular interface by casting.
+  if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) {
+    shapeOp.inferShapes();
+  } else {
+    op->emitError("unable to infer shape of operation without shape "
+                  "inference interface");
+    return signalPassFailure();
+  }
+```
+
+We can then add our pass to the pass manager:
+
+```c++
+  pm.addPass(mlir::createShapeInferencePass());
+```
+
+If we rerun our original example, we now get the following:
+
+```mlir
+func @main() {
+  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+  %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+  %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+  "toy.print"(%2) : (tensor<3x2xf64>) -> ()
+  "toy.return"() : () -> ()
+}
+```
+
+You can build `toyc-ch4` and try it yourself: `toyc-ch4
+test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt`.
+
+In the [next chapter](Ch-5.md), we will start the process of code generation by
+targeting a lower-level dialect for optimizing some of the more compute-heavy
+Toy operations.
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
new file mode 100644
index 00000000000..8a4268b498f
--- /dev/null
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -0,0 +1,357 @@
+# Chapter 5: Partial Lowering to Lower-Level Dialects for Optimization
+
+[TOC]
+
+At this point, we are eager to generate actual code and see our Toy language
+come to life. We will use LLVM to generate code, but just showing the LLVM
+builder interface here wouldn't be very exciting. Instead, we will show how to
+perform progressive lowering through a mix of dialects coexisting in the same
+function.
+
+To make it more interesting, in this chapter we will consider that we want to
+reuse existing optimizations implemented in a dialect optimizing affine
+transformations: `Affine`. This dialect is tailored to the computation-heavy
+part of the program and is limited: it doesn't support representing our
+`toy.print` builtin, for instance, nor should it! Instead, we can target
+`Affine` for the computation-heavy part of Toy, and in the
+[next chapter](Ch-6.md) directly target the `LLVM IR` dialect for lowering
+`print`. As part of this lowering, we will be lowering from the
+[TensorType](../../LangRef.md#tensor-type) that `Toy` operates on to the
+[MemRefType](../../LangRef.md#memref-type) that is indexed via an affine
+loop-nest. Tensors represent an abstract value-typed sequence of data, meaning
+that they don't live in any memory. MemRefs, on the other hand, represent
+lower-level buffer access, as they are concrete references to a region of
+memory.
+
+# Dialect Conversions
+
+MLIR has many different dialects, so it is important to have a unified framework
+for [converting](../../Glossary.md#conversion) between them. This is where the
+`DialectConversion` framework comes into play. This framework allows for
+transforming a set of `illegal` operations to a set of `legal` ones. 
To use this
+framework, we need to provide two things (and an optional third):
+
+* A [Conversion Target](../../DialectConversion.md#conversion-target)
+
+  - This is the formal specification of what operations or dialects are legal
+    for the conversion. Operations that aren't legal will require rewrite
+    patterns to perform [legalization](../../Glossary.md#legalization).
+
+* A set of
+  [Rewrite Patterns](../../DialectConversion.md#rewrite-pattern-specification)
+
+  - These are the set of [patterns](../../QuickstartRewrites.md) used to
+    convert `illegal` operations into a set of zero or more `legal` ones.
+
+* Optionally, a [Type Converter](../../DialectConversion.md#type-conversion).
+
+  - If provided, this is used to convert the types of block arguments. We
+    won't be needing this for our conversion.
+
+## Conversion Target
+
+For our purposes, we want to convert the compute-intensive `Toy` operations into
+a combination of operations from the `Affine` and `Standard` dialects for
+further optimization. To start off the lowering, we first define our conversion
+target:
+
+```c++
+void ToyToAffineLoweringPass::runOnFunction() {
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering.
+  mlir::ConversionTarget target(getContext());
+
+  // We define the specific operations, or dialects, that are legal targets for
+  // this lowering. In our case, we are lowering to a combination of the
+  // `Affine` and `Standard` dialects.
+  target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
+
+  // We also define the Toy dialect as Illegal so that the conversion will fail
+  // if any of these operations are *not* converted. Given that we actually
+  // want a partial lowering, we explicitly mark the Toy operations that we
+  // don't want to lower, `toy.print`, as `legal`.
+  target.addIllegalDialect<ToyDialect>();
+  target.addLegalOp<PrintOp>();
+  ...
+}
+```
+
+## Conversion Patterns
+
+After the conversion target has been defined, we can define how to convert the
+`illegal` operations into `legal` ones. Similarly to the canonicalization
+framework introduced in [chapter 3](Ch-3.md), the
+[`DialectConversion` framework](../../DialectConversion.md) also uses
+[RewritePatterns](../../QuickstartRewrites.md) to perform the conversion logic.
+These patterns may be the `RewritePatterns` seen before or a new type of pattern
+specific to the conversion framework: `ConversionPattern`. `ConversionPatterns`
+are different from traditional `RewritePatterns` in that they accept an
+additional `operands` parameter containing operands that have been
+remapped/replaced. This is used when dealing with type conversions, as the
+pattern will want to operate on values of the new type but match against the
+old. For our lowering, this invariant will be useful as it translates from the
+[TensorType](../../LangRef.md#tensor-type) currently being operated on to the
+[MemRefType](../../LangRef.md#memref-type). Let's look at a snippet of lowering
+the `toy.transpose` operation:
+
+```c++
+/// Lower the `toy.transpose` operation to an affine loop nest.
+struct TransposeOpLowering : public mlir::ConversionPattern {
+  TransposeOpLowering(mlir::MLIRContext *ctx)
+      : mlir::ConversionPattern(TransposeOp::getOperationName(),
+                                /*benefit=*/1, ctx) {}
+
+  /// Match and rewrite the given `toy.transpose` operation, with the given
+  /// operands that have been remapped from `tensor<...>` to `memref<...>`.
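+  /// Returning `matchSuccess()` below reports that the rewrite was applied;
+  /// a pattern that cannot apply would return `matchFailure()` instead.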
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
+                  mlir::ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+
+    // Call to a helper function that will lower the current operation to a set
+    // of affine loops. We provide a functor that operates on the remapped
+    // operands, as well as the loop induction variables for the innermost
+    // loop body.
+    lowerOpToLoops(
+        op, operands, rewriter,
+        [loc](mlir::PatternRewriter &rewriter,
+              ArrayRef<mlir::Value> memRefOperands,
+              ArrayRef<mlir::Value> loopIvs) {
+          // Generate an adaptor for the remapped operands of the TransposeOp.
+          // This allows for using the nice named accessors that are generated
+          // by ODS; the adaptor is automatically provided by the ODS
+          // framework.
+          TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
+          mlir::Value input = transposeAdaptor.input();
+
+          // Transpose the elements by generating a load from the reverse
+          // indices.
+          SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
+          return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
+        });
+    return matchSuccess();
+  }
+};
+```
+
+Now we can prepare the list of patterns to use during the lowering process:
+
+```c++
+void ToyToAffineLoweringPass::runOnFunction() {
+  ...
+
+  // Now that the conversion target has been defined, we just need to provide
+  // the set of patterns that will lower the Toy operations.
+  mlir::OwningRewritePatternList patterns;
+  patterns.insert<..., TransposeOpLowering>(&getContext());
+
+  ...
+```
+
+## Partial Lowering
+
+Once the patterns have been defined, we can perform the actual lowering. The
+`DialectConversion` framework provides several different modes of lowering, but,
+for our purposes, we will perform a partial lowering, as we will not convert
+`toy.print` at this time.
+
+```c++
+void ToyToAffineLoweringPass::runOnFunction() {
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering.
+  mlir::ConversionTarget target(getContext());
+
+  // We define the specific operations, or dialects, that are legal targets for
+  // this lowering. In our case, we are lowering to a combination of the
+  // `Affine` and `Standard` dialects.
+  target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
+
+  // We also define the Toy dialect as Illegal so that the conversion will fail
+  // if any of these operations are *not* converted. Given that we actually
+  // want a partial lowering, we explicitly mark the Toy operations that we
+  // don't want to lower, `toy.print`, as `legal`.
+  target.addIllegalDialect<ToyDialect>();
+  target.addLegalOp<PrintOp>();
+
+  // Now that the conversion target has been defined, we just need to provide
+  // the set of patterns that will lower the Toy operations.
+  mlir::OwningRewritePatternList patterns;
+  patterns.insert<..., TransposeOpLowering>(&getContext());
+
+  // With the target and rewrite patterns defined, we can now attempt the
+  // conversion. The conversion will signal failure if any of our `illegal`
+  // operations were not converted successfully.
+  auto function = getFunction();
+  if (mlir::failed(mlir::applyPartialConversion(function, target, patterns)))
+    signalPassFailure();
+}
+```
+
+### Design Considerations With Partial Lowering
+
+Before diving into the result of our lowering, this is a good time to discuss
+potential design considerations when it comes to partial lowering. 
In our
+lowering, we transform from a value type, TensorType, to an allocated
+(buffer-like) type, MemRefType. However, given that we do not lower the
+`toy.print` operation, we need to temporarily bridge these two worlds. There are
+many ways to go about this, each with its own tradeoffs:
+
+* Generate `load` operations from the buffer
+
+One option is to generate `load` operations from the buffer type to materialize
+an instance of the value type. This allows for the definition of the `toy.print`
+operation to remain unchanged. The downside to this approach is that the
+optimizations on the `affine` dialect are limited, because the `load` will
+actually involve a full copy that is only visible *after* our optimizations have
+been performed.
+
+* Generate a new version of `toy.print` that operates on the lowered type
+
+Another option would be to have another, lowered, variant of `toy.print` that
+operates on the lowered type. The benefit of this option is that no hidden,
+unnecessary copy is presented to the optimizer. The downside is that another
+operation definition is needed that may duplicate many aspects of the first.
+Defining a base class in [ODS](../../OpDefinitions.md) may simplify this, but
+you still need to treat these operations separately.
+
+* Update `toy.print` to allow for operating on the lowered type
+
+A third option is to update the current definition of `toy.print` to allow for
+operating on the lowered type. The benefit of this approach is that it is
+simple, does not introduce an additional hidden copy, and does not require
+another operation definition. The downside to this option is that it requires
+mixing abstraction levels in the `Toy` dialect.
+
+For the sake of simplicity, we will use the third option for this lowering. This
+involves updating the type constraints on the PrintOp in the operation
+definition file:
+
+```tablegen
+def PrintOp : Toy_Op<"print"> {
+  ...
+
+  // The print operation takes an input tensor to print.
+  // We also allow a F64MemRef to enable interop during partial lowering.
+  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
+}
+```
+
+## Complete Toy Example
+
+Looking back at our current working example:
+
+```mlir
+func @main() {
+  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+  %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+  %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+  "toy.print"(%3) : (tensor<3x2xf64>) -> ()
+  "toy.return"() : () -> ()
+}
+```
+
+With affine lowering added to our pipeline, we can now generate:
+
+```mlir
+func @main() {
+  %cst = constant 1.000000e+00 : f64
+  %cst_0 = constant 2.000000e+00 : f64
+  %cst_1 = constant 3.000000e+00 : f64
+  %cst_2 = constant 4.000000e+00 : f64
+  %cst_3 = constant 5.000000e+00 : f64
+  %cst_4 = constant 6.000000e+00 : f64
+
+  // Allocating buffers for the inputs and outputs.
+  %0 = alloc() : memref<3x2xf64>
+  %1 = alloc() : memref<3x2xf64>
+  %2 = alloc() : memref<2x3xf64>
+
+  // Initialize the input buffer with the constant values.
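+  // (There is no affine form of toy.constant, so the lowering writes each
+  //  element out individually with an affine.store.)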
+  affine.store %cst, %2[0, 0] : memref<2x3xf64>
+  affine.store %cst_0, %2[0, 1] : memref<2x3xf64>
+  affine.store %cst_1, %2[0, 2] : memref<2x3xf64>
+  affine.store %cst_2, %2[1, 0] : memref<2x3xf64>
+  affine.store %cst_3, %2[1, 1] : memref<2x3xf64>
+  affine.store %cst_4, %2[1, 2] : memref<2x3xf64>
+
+  // Load the transpose value from the input buffer and store it into the
+  // next input buffer.
+  affine.for %arg0 = 0 to 3 {
+    affine.for %arg1 = 0 to 2 {
+      %3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
+      affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64>
+    }
+  }
+
+  // Multiply and store into the output buffer.
+  affine.for %arg0 = 0 to 2 {
+    affine.for %arg1 = 0 to 3 {
+      %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
+      %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
+      %5 = mulf %3, %4 : f64
+      affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64>
+    }
+  }
+
+  // Print the value held by the buffer.
+  "toy.print"(%0) : (memref<3x2xf64>) -> ()
+  dealloc %2 : memref<2x3xf64>
+  dealloc %1 : memref<3x2xf64>
+  dealloc %0 : memref<3x2xf64>
+  return
+}
+```
+
+## Taking Advantage of Affine Optimization
+
+Our naive lowering is correct, but it leaves a lot to be desired with regard to
+efficiency. For example, the lowering of `toy.mul` has generated some redundant
+loads. Let's look at how adding a few existing optimizations to the pipeline can
+help clean this up. Adding the `LoopFusion` and `MemRefDataFlowOpt` passes to
+the pipeline (a sketch of registering them appears at the end of this chapter)
+gives the following result:
+
+```mlir
+func @main() {
+  %cst = constant 1.000000e+00 : f64
+  %cst_0 = constant 2.000000e+00 : f64
+  %cst_1 = constant 3.000000e+00 : f64
+  %cst_2 = constant 4.000000e+00 : f64
+  %cst_3 = constant 5.000000e+00 : f64
+  %cst_4 = constant 6.000000e+00 : f64
+
+  // Allocating buffers for the inputs and outputs.
+  %0 = alloc() : memref<3x2xf64>
+  %1 = alloc() : memref<2x3xf64>
+
+  // Initialize the input buffer with the constant values.
+  affine.store %cst, %1[0, 0] : memref<2x3xf64>
+  affine.store %cst_0, %1[0, 1] : memref<2x3xf64>
+  affine.store %cst_1, %1[0, 2] : memref<2x3xf64>
+  affine.store %cst_2, %1[1, 0] : memref<2x3xf64>
+  affine.store %cst_3, %1[1, 1] : memref<2x3xf64>
+  affine.store %cst_4, %1[1, 2] : memref<2x3xf64>
+
+  affine.for %arg0 = 0 to 3 {
+    affine.for %arg1 = 0 to 2 {
+      // Load the transpose value from the input buffer.
+      %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>
+
+      // Multiply and store into the output buffer.
+      %3 = mulf %2, %2 : f64
+      affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64>
+    }
+  }
+
+  // Print the value held by the buffer.
+  "toy.print"(%0) : (memref<3x2xf64>) -> ()
+  dealloc %1 : memref<2x3xf64>
+  dealloc %0 : memref<3x2xf64>
+  return
+}
+```
+
+Here, we can see that a redundant allocation was removed, the two loop nests
+were fused, and some unnecessary `load`s were removed. You can build `toyc-ch5`
+and try it yourself: `toyc-ch5 test/lowering.toy -emit=mlir-affine`. We can also
+check our optimizations by adding `-opt`.
+
+In this chapter we explored some aspects of partial lowering, with the intent to
+optimize. In the [next chapter](Ch-6.md) we will continue the discussion about
+dialect conversion by targeting LLVM for code generation. 
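+
+For reference, here is a minimal sketch of how the passes used in this chapter
+might be registered. The upstream factory functions
+`mlir::createLoopFusionPass()` and `mlir::createMemRefDataFlowOptPass()` create
+the two optimizations shown above; the Toy-side name
+`createLowerToAffinePass()` is an assumption for illustration:
+
+```c++
+// A minimal sketch, assuming `pm` is the pass manager built up in earlier
+// chapters and that this chapter's lowering is exposed through a factory
+// such as createLowerToAffinePass() (hypothetical name).
+pm.addPass(mlir::toy::createLowerToAffinePass());
+
+// The two affine optimizations discussed in this chapter.
+pm.addPass(mlir::createLoopFusionPass());
+pm.addPass(mlir::createMemRefDataFlowOptPass());
+```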
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
new file mode 100644
index 00000000000..939b2b4f776
--- /dev/null
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -0,0 +1,323 @@
+# Chapter 6: Lowering to LLVM and Code Generation
+
+[TOC]
+
+In the [previous chapter](Ch-5.md), we introduced the
+[dialect conversion](../../DialectConversion.md) framework and partially lowered
+many of the `Toy` operations to affine loop nests for optimization. In this
+chapter, we will finally lower to LLVM for code generation.
+
+# Lowering to LLVM
+
+For this lowering, we will again use the dialect conversion framework to perform
+the heavy lifting. However, this time, we will be performing a full conversion
+to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
+lowered all but one of the `toy` operations, with the last being `toy.print`.
+Before going over the conversion to LLVM, let's lower the `toy.print` operation.
+We will lower this operation to a non-affine loop nest that invokes `printf` for
+each element. Note that, because the dialect conversion framework supports
+[transitive lowering](../../Glossary.md#transitive-lowering), we don't need to
+directly emit operations in the LLVM dialect. By transitive lowering, we mean
+that the conversion framework may apply multiple patterns to fully legalize an
+operation. In this example, we are generating a structured loop nest instead of
+the branch form used in the LLVM dialect. As long as we then have a lowering
+from the loop operations to LLVM, the lowering will still succeed.
+
+During lowering we can get, or build, the declaration for printf like so:
+
+```c++
+/// Return a symbol reference to the printf function, inserting it into the
+/// module if necessary.
+static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                           ModuleOp module,
+                                           LLVM::LLVMDialect *llvmDialect) {
+  auto *context = module.getContext();
+  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
+    return SymbolRefAttr::get("printf", context);
+
+  // Create a function declaration for printf, the signature is:
+  //   * `i32 (i8*, ...)`
+  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
+                                                  /*isVarArg=*/true);
+
+  // Insert the printf function into the body of the parent module.
+  PatternRewriter::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointToStart(module.getBody());
+  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
+  return SymbolRefAttr::get("printf", context);
+}
+```
+
+Now that the lowering for the printf operation has been defined, we can specify
+the components necessary for the lowering. These are largely the same as the
+components defined in the [previous chapter](Ch-5.md).
+
+## Conversion Target
+
+For this conversion, aside from the top-level module, we will be lowering
+everything to the LLVM dialect.
+
+```c++
+  mlir::ConversionTarget target(getContext());
+  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
+  target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
+```
+
+## Type Converter
+
+This lowering will also transform the MemRef types that are currently being
+operated on into a representation in LLVM. To perform this conversion, we use a
+TypeConverter as part of the lowering. This converter specifies how one type
+maps to another. 
This is necessary now that we are performing more complicated
+lowerings involving block arguments. Given that we don't have any
+Toy-dialect-specific types that need to be lowered, the default converter is
+enough for our use case.
+
+```c++
+  LLVMTypeConverter typeConverter(&getContext());
+```
+
+## Conversion Patterns
+
+Now that the conversion target has been defined, we need to provide the patterns
+used for lowering. At this point in the compilation process, we have a
+combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and
+`affine` dialects already provide the set of patterns needed to transform them
+into LLVM dialect. These patterns allow for lowering the IR in multiple stages
+by relying on [transitive lowering](../../Glossary.md#transitive-lowering).
+
+```c++
+  mlir::OwningRewritePatternList patterns;
+  mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
+  mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
+  mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+  // The only remaining operation to lower from the `toy` dialect is the
+  // PrintOp.
+  patterns.insert<PrintOpLowering>(&getContext());
+```
+
+## Full Lowering
+
+We want to completely lower to LLVM, so we use a full conversion. This ensures
+that only legal operations will remain after the conversion.
+
+```c++
+  mlir::ModuleOp module = getModule();
+  if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
+                                             &typeConverter)))
+    signalPassFailure();
+```
+
+Looking back at our current working example:
+
+```mlir
+func @main() {
+  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+  %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+  %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+  "toy.print"(%3) : (tensor<3x2xf64>) -> ()
+  "toy.return"() : () -> ()
+}
+```
+
+We can now lower to the LLVM dialect, which produces the following code:
+
+```mlir
+llvm.func @free(!llvm<"i8*">)
+llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
+llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
+llvm.func @main() {
+  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
+  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double
+
+  ...
+
+^bb16:
+  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
+  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
+  %224 = llvm.mul %214, %223 : !llvm.i64
+  %225 = llvm.add %222, %224 : !llvm.i64
+  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %227 = llvm.mul %219, %226 : !llvm.i64
+  %228 = llvm.add %225, %227 : !llvm.i64
+  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
+  %230 = llvm.load %229 : !llvm<"double*">
+  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
+  %232 = llvm.add %219, %218 : !llvm.i64
+  llvm.br ^bb15(%232 : !llvm.i64)
+
+  ...
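+  // (^bb18 below is the exit block: the buffers allocated earlier are
+  //  released via the llvm.call @free calls generated by the MemRef lowering.)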
+
+^bb18:
+  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
+  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
+  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
+  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
+  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
+  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
+  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
+  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
+  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
+  llvm.return
+}
+```
+
+See [Conversion to the LLVM IR Dialect](../../ConversionToLLVMDialect.md) for
+more in-depth details on lowering to the LLVM dialect.
+
+# CodeGen: Getting Out of MLIR
+
+At this point we are right at the cusp of code generation. We can generate code
+in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT
+to run it.
+
+## Emitting LLVM IR
+
+Now that our module consists only of operations in the LLVM dialect, we can
+export to LLVM IR. To do this programmatically, we can invoke the following
+utility:
+
+```c++
+  std::unique_ptr<llvm::Module> llvmModule = mlir::translateModuleToLLVMIR(module);
+  if (!llvmModule)
+    /* ... an error was encountered ... */
+```
+
+Exporting our module to LLVM IR generates:
+
+```.llvm
+define void @main() {
+  ...
+
+102:
+  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
+  %104 = mul i64 %96, 2
+  %105 = add i64 0, %104
+  %106 = mul i64 %100, 1
+  %107 = add i64 %105, %106
+  %108 = getelementptr double, double* %103, i64 %107
+  %109 = load double, double* %108
+  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
+  %111 = add i64 %100, 1
+  br label %99
+
+  ...
+
+115:
+  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
+  %117 = bitcast double* %116 to i8*
+  call void @free(i8* %117)
+  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
+  %119 = bitcast double* %118 to i8*
+  call void @free(i8* %119)
+  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
+  %121 = bitcast double* %120 to i8*
+  call void @free(i8* %121)
+  ret void
+}
+```
+
+If we enable optimization on the generated LLVM IR, we can trim this down quite
+a bit:
+
+```.llvm
+define void @main() {
+  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
+  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
+  %putchar = tail call i32 @putchar(i32 10)
+  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
+  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
+  %putchar.1 = tail call i32 @putchar(i32 10)
+  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
+  %5 = tail call i32 (i8*, ...) 
@printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01) + %putchar.2 = tail call i32 @putchar(i32 10) + ret void +} + +``` + +The full code listing for dumping LLVM IR can be found in `Ch6/toy.cpp` in the +`dumpLLVMIR()` function: + +```c++ + +int dumpLLVMIR(mlir::ModuleOp module) { + // Translate the module, that contains the LLVM dialect, to LLVM IR. + auto llvmModule = mlir::translateModuleToLLVMIR(module); + if (!llvmModule) { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} +``` + +## Setting up a JIT + +Setting up a JIT to run the module containing the LLVM dialect can be done using +the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around +LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up +the JIT can be found in `Ch6/toy.cpp` in the `runJit()` function: + +```c++ +int runJit(mlir::ModuleOp module) { + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invoke("main"); + if (invocationResult) { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} +``` + +You can play around with it from the build directory: + +```sh +$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit +1.000000 2.000000 +3.000000 4.000000 +``` + +You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and +`-emit=llvm` to compare the various levels of IR involved. Also try options like +[`--print-ir-after-all`](../../WritingAPass.md#ir-printing) to track the +evolution of the IR throughout the pipeline. + +So far, we have worked with primitive data types. In the +[next chapter](Ch-7.md), we will add a composite `struct` type. diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md new file mode 100644 index 00000000000..6298e8253e9 --- /dev/null +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -0,0 +1,539 @@ +# Chapter 7: Adding a Composite Type to Toy + +[TOC] + +In the [previous chapter](Ch-6.md), we demonstrated an end-to-end compilation +flow from our Toy front-end to LLVM IR. In this chapter, we will extend the Toy +language to support a new composite `struct` type. + +## Defining a `struct` in Toy + +The first thing we need to define is the interface of this type in our `toy` +source language. 
The general syntax of a `struct` type in Toy is as follows:
+
+```toy
+# A struct is defined by using the `struct` keyword followed by a name.
+struct MyStruct {
+  # Inside of the struct is a list of variable declarations without initializers
+  # or shapes, which may also be other previously defined structs.
+  var a;
+  var b;
+}
+```
+
+Structs may now be used in functions as variables or parameters by using the
+name of the struct instead of `var`. The members of the struct are accessed via
+the `.` access operator. Values of `struct` type may be initialized with a
+composite initializer: a comma-separated list of other initializers surrounded
+by `{}`. An example is shown below:
+
+```toy
+struct Struct {
+  var a;
+  var b;
+}
+
+# User defined generic functions may operate on struct types as well.
+def multiply_transpose(Struct value) {
+  # We can access the elements of a struct via the '.' operator.
+  return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+  # We initialize struct values using a composite initializer.
+  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+  # We pass these arguments to functions like we do with variables.
+  var c = multiply_transpose(value);
+  print(c);
+}
+```
+
+## Defining a `struct` in MLIR
+
+In MLIR, we will also need a representation for our struct types. MLIR does not
+provide a type that does exactly what we need, so we will need to define our
+own. We will simply define our `struct` as an unnamed container of a set of
+element types. The name of the `struct` and its elements are only useful for the
+AST of our `toy` compiler, so we don't need to encode them in the MLIR
+representation.
+
+### Defining the Type Class
+
+#### Reserving a Range of Type Kinds
+
+Types in MLIR rely on having a unique `kind` value to ensure that casting checks
+remain extremely efficient
+([rationale](../../Rationale.md#reserving-dialect-type-kinds)). For `toy`, this
+means we need to explicitly reserve a static range of type `kind` values in the
+symbol registry file
+[DialectSymbolRegistry](https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/DialectSymbolRegistry.def).
+
+```c++
+DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
+DEFINE_SYM_KIND_RANGE(TOY)    // Toy language (tutorial) Dialect
+
+// The following ranges are reserved for experimenting with MLIR dialects in a
+// private context without having to register them here.
+DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
+```
+
+These definitions will provide a range in the Type::Kind enum to use when
+defining the derived types.
+
+```c++
+/// Create a local enumeration with all of the types that are defined by Toy.
+namespace ToyTypes {
+enum Types {
+  Struct = mlir::Type::FIRST_TOY_TYPE,
+};
+} // end namespace ToyTypes
+```
+
+#### Defining the Type Class
+
+As mentioned in [chapter 2](Ch-2.md), [`Type`](../../LangRef.md#type-system)
+objects in MLIR are value-typed and rely on having an internal storage object
+that holds the actual data for the type. The `Type` class in itself acts as a
+simple wrapper around an internal `TypeStorage` object that is uniqued within an
+instance of an `MLIRContext`. When constructing a `Type`, we are internally just
+constructing and uniquing an instance of a storage class.
+
+When defining a new `Type` that requires additional information beyond just the
+`kind` (e.g. the `struct` type, which requires additional information to hold
+the element types), we will need to provide a derived storage class. 
The
+`primitive` types that don't have any additional data (e.g. the
+[`index` type](../../LangRef.md#index-type)) don't require a storage class.
+
+##### Defining the Storage Class
+
+Type storage objects contain all of the data necessary to construct and unique a
+type instance. Derived storage classes must inherit from the base
+`mlir::TypeStorage` and provide a set of aliases and hooks that will be used by
+the `MLIRContext` for uniquing. Below is the definition of the storage instance
+for our `struct` type, with each of the necessary requirements detailed inline:
+
+```c++
+/// This class represents the internal storage of the Toy `StructType`.
+struct StructTypeStorage : public mlir::TypeStorage {
+  /// The `KeyTy` is a required type that provides an interface for the storage
+  /// instance. This type will be used when uniquing an instance of the type
+  /// storage. For our struct type, we will unique each instance structurally on
+  /// the elements that it contains.
+  using KeyTy = llvm::ArrayRef<mlir::Type>;
+
+  /// A constructor for the type storage instance.
+  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
+      : elementTypes(elementTypes) {}
+
+  /// Define the comparison function for the key type with the current storage
+  /// instance. This is used when constructing a new instance to ensure that we
+  /// haven't already uniqued an instance of the given key.
+  bool operator==(const KeyTy &key) const { return key == elementTypes; }
+
+  /// Define a hash function for the key type. This is used when uniquing
+  /// instances of the storage.
+  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
+  /// have hash functions available, so we could just omit this entirely.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  /// Define a construction function for the key type from a set of parameters.
+  /// These parameters will be provided when constructing the storage instance
+  /// itself, see the `StructType::get` method further below.
+  /// Note: This method isn't necessary because KeyTy can be directly
+  /// constructed with the given parameters.
+  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
+    return KeyTy(elementTypes);
+  }
+
+  /// Define a construction method for creating a new instance of this storage.
+  /// This method takes an instance of a storage allocator, and an instance of a
+  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
+  /// allocations used to create the type storage and its internal data.
+  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    // Copy the elements from the provided `KeyTy` into the allocator.
+    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
+
+    // Allocate the storage instance and construct it.
+    return new (allocator.allocate<StructTypeStorage>())
+        StructTypeStorage(elementTypes);
+  }
+
+  /// The following field contains the element types of the struct.
+  llvm::ArrayRef<mlir::Type> elementTypes;
+};
+```
+
+##### Defining the Type Class
+
+With the storage class defined, we can add the definition for the user-visible
+`StructType` class. This is the class that we will actually interface with.
+
+```c++
+/// This class defines the Toy struct type. It represents a collection of
+/// element types. All derived types in MLIR must inherit from the CRTP class
+/// 'Type::TypeBase'. 
It takes as template parameters the concrete type
+/// (StructType), the base class to use (Type), and the storage class
+/// (StructTypeStorage).
+class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
+                                               StructTypeStorage> {
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+  /// This static method is used to support type inquiry through isa, cast,
+  /// and dyn_cast.
+  static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; }
+
+  /// Create an instance of a `StructType` with the given element types. There
+  /// *must* be at least one element type.
+  static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) {
+    assert(!elementTypes.empty() && "expected at least 1 element type");
+
+    // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
+    // of this type. The first two parameters are the context to unique in and
+    // the kind of the type. The parameters after the type kind are forwarded to
+    // the storage instance.
+    mlir::MLIRContext *ctx = elementTypes.front().getContext();
+    return Base::get(ctx, ToyTypes::Struct, elementTypes);
+  }
+
+  /// Returns the element types of this struct type.
+  llvm::ArrayRef<mlir::Type> getElementTypes() {
+    // 'getImpl' returns a pointer to the internal storage instance.
+    return getImpl()->elementTypes;
+  }
+
+  /// Returns the number of element types held by this struct.
+  size_t getNumElementTypes() { return getElementTypes().size(); }
+};
+```
+
+We register this type in the `ToyDialect` constructor in a similar way to how we
+did with operations:
+
+```c++
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx) {
+  addTypes<StructType>();
+}
+```
+
+With this we can now use our `StructType` when generating MLIR from Toy. See
+examples/toy/Ch7/mlir/MLIRGen.cpp for more details.
+
+### Parsing and Printing
+
+At this point we can use our `StructType` during MLIR generation and
+transformation, but we can't output or parse `.mlir`. For this we need to add
+support for parsing and printing instances of the `StructType`. This can be done
+by overriding the `parseType` and `printType` methods on the `ToyDialect`.
+
+```c++
+class ToyDialect : public mlir::Dialect {
+public:
+  /// Parse an instance of a type registered to the toy dialect.
+  mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
+
+  /// Print an instance of a type registered to the toy dialect.
+  void printType(mlir::Type type,
+                 mlir::DialectAsmPrinter &printer) const override;
+};
+```
+
+These methods take an instance of a high-level parser or printer that allows for
+easily implementing the necessary functionality. Before going into the
+implementation, let's think about the syntax that we want for the `struct` type
+in the printed IR. As described in the
+[MLIR language reference](../../LangRef.md#dialect-types), dialect types are
+generally represented as: `! dialect-namespace < type-data >`, with a pretty
+form available under certain circumstances. The responsibility of our `Toy`
+parser and printer is to provide the `type-data` bits. We will define our
+`StructType` as having the following form:
+
+```
+  struct-type ::= `struct` `<` type (`,` type)* `>`
+```
+
+#### Parsing
+
+An implementation of the parser is shown below:
+
+```c++
+/// Parse an instance of a type registered to the toy dialect.
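+/// Returning a null type (`Type()` or `nullptr`, which are equivalent here)
+/// reports a parse failure; the parser hooks used below emit their own
+/// diagnostics.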
+mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
+  // Parse a struct type in the following form:
+  //   struct-type ::= `struct` `<` type (`,` type)* `>`
+
+  // NOTE: All MLIR parser functions return a ParseResult. This is a
+  // specialization of LogicalResult that auto-converts to a `true` boolean
+  // value on failure to allow for chaining, but may be used with explicit
+  // `mlir::failed/mlir::succeeded` as desired.
+
+  // Parse: `struct` `<`
+  if (parser.parseKeyword("struct") || parser.parseLess())
+    return Type();
+
+  // Parse the element types of the struct.
+  SmallVector<mlir::Type, 1> elementTypes;
+  do {
+    // Parse the current element type.
+    llvm::SMLoc typeLoc = parser.getCurrentLocation();
+    mlir::Type elementType;
+    if (parser.parseType(elementType))
+      return nullptr;
+
+    // Check that the type is either a TensorType or another StructType.
+    if (!elementType.isa<mlir::TensorType>() &&
+        !elementType.isa<StructType>()) {
+      parser.emitError(typeLoc, "element type for a struct must either "
+                                "be a TensorType or a StructType, got: ")
+          << elementType;
+      return Type();
+    }
+    elementTypes.push_back(elementType);
+
+    // Parse the optional: `,`
+  } while (succeeded(parser.parseOptionalComma()));
+
+  // Parse: `>`
+  if (parser.parseGreater())
+    return Type();
+  return StructType::get(elementTypes);
+}
+```
+
+#### Printing
+
+An implementation of the printer is shown below:
+
+```c++
+/// Print an instance of a type registered to the toy dialect.
+void ToyDialect::printType(mlir::Type type,
+                           mlir::DialectAsmPrinter &printer) const {
+  // Currently the only toy type is a struct type.
+  StructType structType = type.cast<StructType>();
+
+  // Print the struct type according to the parser format.
+  printer << "struct<";
+  mlir::interleaveComma(structType.getElementTypes(), printer);
+  printer << '>';
+}
+```
+
+Before moving on, let's look at a quick example showcasing the functionality we
+have now:
+
+```toy
+struct Struct {
+  var a;
+  var b;
+}
+
+def multiply_transpose(Struct value) {
+}
+```
+
+This generates the following:
+
+```mlir
+module {
+  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+### Operating on `StructType`
+
+Now that the `struct` type has been defined, we can round-trip it through the
+IR. The next step is to add support for using it within our operations.
+
+#### Updating Existing Operations
+
+A few of our existing operations will need to be updated to handle `StructType`.
+The first step is to make the ODS framework aware of our Type so that we can use
+it in the operation definitions. A simple example is shown below:
+
+```tablegen
+// Provide a definition for the Toy StructType for use in ODS. This allows for
+// using StructType in a similar way to Tensor or MemRef.
+def Toy_StructType :
+    Type<CPred<"$_self.isa<StructType>()">, "Toy struct type">;
+
+// Provide a definition of the types that are used within the Toy dialect.
+def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
+```
+
+We can then update our operations, e.g. `ReturnOp`, to also accept the
+`Toy_StructType`:
+
+```tablegen
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+  ...
+  let arguments = (ins Variadic<Toy_Type>:$input);
+  ...
+}
+```
+
+#### Adding New `Toy` Operations
+
+In addition to the existing operations, we will be adding a few new operations
+that will provide more specific handling of `structs`. 
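+
+Both of the operations described in the following subsections,
+`toy.struct_constant` and `toy.struct_access`, are defined in ODS in the same
+manner as the operations from the earlier chapters. As a rough orientation, a
+hypothetical sketch of creating them from C++ follows; the builder signatures
+shown are illustrative assumptions, not the generated API:
+
+```c++
+// Hypothetical sketch only: materialize a struct constant, then extract its
+// first element. `structType`, `elementsAttr`, and `elementType` are assumed
+// to be in scope.
+mlir::Value structValue =
+    builder.create<StructConstantOp>(loc, structType, elementsAttr);
+mlir::Value element = builder.create<StructAccessOp>(
+    loc, elementType, structValue, builder.getI64IntegerAttr(0));
+```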
+
+##### `toy.struct_constant`
+
+This new operation materializes a constant value for a struct. In our current
+modeling, we just use an [array attribute](../../LangRef.md#array-attribute)
+that contains a set of constant values for each of the `struct` elements.
+
+```mlir
+  %0 = "toy.struct_constant"() {
+    value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+  } : () -> !toy.struct<tensor<*xf64>>
+```
+
+##### `toy.struct_access`
+
+This new operation materializes the Nth element of a `struct` value.
+
+```mlir
+  %0 = "toy.struct_constant"() {
+    value = [dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>]
+  } : () -> !toy.struct<tensor<*xf64>>
+  %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>>) -> tensor<*xf64>
+```
+
+With these operations, we can revisit our original example:
+
+```toy
+struct Struct {
+  var a;
+  var b;
+}
+
+# User defined generic functions may operate on struct types as well.
+def multiply_transpose(Struct value) {
+  # We can access the elements of a struct via the '.' operator.
+  return transpose(value.a) * transpose(value.b);
+}
+
+def main() {
+  # We initialize struct values using a composite initializer.
+  Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]};
+
+  # We pass these arguments to functions like we do with variables.
+  var c = multiply_transpose(value);
+  print(c);
+}
+```
+
+and finally get a full MLIR module:
+
+```mlir
+module {
+  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
+    %0 = "toy.struct_access"(%arg0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+    %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
+    %2 = "toy.struct_access"(%arg0) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+    %3 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
+    %4 = "toy.mul"(%1, %3) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+    "toy.return"(%4) : (tensor<*xf64>) -> ()
+  }
+  func @main() {
+    %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>>
+    %1 = "toy.generic_call"(%0) {callee = @multiply_transpose} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+    "toy.print"(%1) : (tensor<*xf64>) -> ()
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+#### Optimizing Operations on `StructType`
+
+Now that we have a few operations operating on `StructType`, we also have many
+new constant folding opportunities.
+
+After inlining, the MLIR module in the previous section looks something like:
+
+```mlir
+module {
+  func @main() {
+    %0 = "toy.struct_constant"() {value = [dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>]} : () -> !toy.struct<tensor<*xf64>, tensor<*xf64>>
+    %1 = "toy.struct_access"(%0) {index = 0 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+    %2 = "toy.transpose"(%1) : (tensor<*xf64>) -> tensor<*xf64>
+    %3 = "toy.struct_access"(%0) {index = 1 : i64} : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
+    %4 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
+    %5 = "toy.mul"(%2, %4) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+    "toy.print"(%5) : (tensor<*xf64>) -> ()
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+We have several `toy.struct_access` operations that access into a
+`toy.struct_constant`. As detailed in [chapter 3](Ch-3.md), we can add folders
+for these `toy` operations by setting the `hasFolder` bit on the operation
+definition and providing a definition of the `*Op::fold` method.
+
+```c++
+/// Fold constants.
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
+
+/// Fold struct constants.
+OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
+  return value();
+}
+
+/// Fold simple struct access operations that access into a constant.
+OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
+  auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+  if (!structAttr)
+    return nullptr;
+
+  size_t elementIndex = index().getZExtValue();
+  return structAttr.getValue()[elementIndex];
+}
+```
+
+To ensure that MLIR generates the proper constant operations when folding our
+`Toy` operations, i.e. `ConstantOp` for `TensorType` and `StructConstantOp` for
+`StructType`, we will need to provide an override for the dialect hook
+`materializeConstant`. This allows for generic MLIR operations to create
+constants for the `Toy` dialect when necessary.
+
+```c++
+mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                 mlir::Attribute value,
+                                                 mlir::Type type,
+                                                 mlir::Location loc) {
+  if (type.isa<StructType>())
+    return builder.create<StructConstantOp>(loc, type,
+                                            value.cast<mlir::ArrayAttr>());
+  return builder.create<ConstantOp>(loc, type,
+                                    value.cast<mlir::DenseElementsAttr>());
+}
+```
+
+With this, we can now generate code that can be lowered to LLVM without any
+changes to our pipeline.
+
+```mlir
+module {
+  func @main() {
+    %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+    %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+    %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+    "toy.print"(%2) : (tensor<3x2xf64>) -> ()
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+You can build `toyc-ch7` and try it yourself: `toyc-ch7
+test/Examples/Toy/Ch7/struct-codegen.toy -emit=mlir`. More details on defining
+custom types can be found in
+[DefiningAttributesAndTypes](../../DefiningAttributesAndTypes.md).
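+
+As a final note, the `kindof` hook defined earlier in this chapter is what
+drives LLVM-style casting on the new type. A small usage sketch, assuming a
+`mlir::Type` value named `type` is in scope:
+
+```c++
+// Works because StructType::kindof() recognizes ToyTypes::Struct.
+if (auto structTy = type.dyn_cast<StructType>())
+  llvm::errs() << "struct with " << structTy.getNumElementTypes()
+               << " element type(s)\n";
+```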