Diffstat (limited to 'mlir/docs/GenericDAGRewriter.md')
-rw-r--r-- | mlir/docs/GenericDAGRewriter.md | 415 |
1 files changed, 415 insertions, 0 deletions
diff --git a/mlir/docs/GenericDAGRewriter.md b/mlir/docs/GenericDAGRewriter.md
new file mode 100644
index 00000000000..8cc09f7d17f
--- /dev/null
+++ b/mlir/docs/GenericDAGRewriter.md
@@ -0,0 +1,415 @@

# MLIR Generic DAG Rewriter Infrastructure

## Introduction and Motivation

The goal of a compiler IR is to represent code at various levels of
abstraction, each of which poses a different set of tradeoffs between
representational capability and ease of transformation. However, the ability to
represent code is not itself very useful - you also need to be able to
implement those transformations.

There are many different sorts of compiler transformations, but this document
focuses on a particularly important class of transformation that comes up
repeatedly at scale, and is important for the immediate goals of MLIR: pattern
matching on a set of operations and replacing them with another set. This is
the key algorithm required to implement the "op fission" transformation used by
the tf2xla bridge, pattern-matching rewrites from TF ops to TF/Lite, and
peephole optimizations like "eliminate identity nodes" or "replace x+0 with x",
and it is a useful abstraction for implementing optimization algorithms on MLIR
graphs at all levels.

A particular strength of MLIR (and a major difference from other compiler
infrastructures like LLVM, GCC, XLA, TensorFlow, etc.) is that it uses a single
compiler IR to represent code at multiple levels of abstraction: an MLIR
operation can be a "TensorFlow operation", an "XLA HLO", a "TF Lite
FlatBufferModel op", a TPU LLO instruction, an LLVM IR instruction (transitively
including X86, Lanai, CUDA, and other target-specific instructions), or anything
else that the MLIR type system can reasonably express. Because MLIR spans such a
wide range of different problems, a single infrastructure for performing
graph-to-graph rewrites can help solve many diverse domain challenges, from the
TensorFlow graph level down to the machine code level.

[Static single assignment](https://en.wikipedia.org/wiki/Static_single_assignment_form)
(SSA) representations like MLIR make it easy to access the operands and "users"
of an operation. As such, a natural abstraction for these graph-to-graph
rewrites is that of DAG pattern matching: clients define DAG tile patterns, and
each pattern includes a result DAG to produce and the cost of the result (or,
inversely, the benefit of doing the replacement). A common infrastructure
efficiently finds and performs the rewrites.

While this concept is simple, the details are more nuanced. This proposal
defines and explores a set of abstractions that we feel can be applied to the
wide range of problems that MLIR faces - and is expected to face - over time. We
do this by separating the pattern definition and matching algorithm from the
"driver" of the computation loop, and by making space for the patterns to be
defined declaratively in the future.

## Related Work

There is a huge amount of related work to consider, given that pretty much every
compiler in existence has to solve this problem many times over. Here are a few
graph rewrite systems we have used, along with the pros and cons of this related
work. One unifying problem with all of these is that each system is only trying
to solve one particular and usually narrow problem: our proposal would like to
solve many of these problems with a single infrastructure.
Of these, the design most similar to our proposal is the LLVM DAG-to-DAG
instruction selection algorithm, discussed last.

### Constant folding

A degenerate but pervasive case of DAG-to-DAG pattern matching is constant
folding: an operation whose operands are constants can often be folded to a
constant result value.

MLIR already has constant folding routines, which provide a simpler API than a
general DAG-to-DAG pattern matcher, and we expect them to remain because the
simpler contract makes them applicable in some cases that a generic matcher
would not be. For example, a DAG rewrite can remove arbitrary nodes in the
current function, which could invalidate iterators. Constant folding as an API
does not remove any nodes; it just provides a (list of) constant values and
allows the clients to update their data structures as necessary.

### AST-Level Pattern Matchers

The literature is full of source-to-source translators which transform
identities in order to improve performance (e.g. transforming `X*0` into `0`).
One large example that I'm aware of is the GCC `fold` function, which performs
[many optimizations](https://github.com/gcc-mirror/gcc/blob/master/gcc/fold-const.c)
on ASTs. Clang has
[similar routines](http://releases.llvm.org/3.5.0/tools/clang/docs/InternalsManual.html#constant-folding-in-the-clang-ast)
for simple constant folding of expressions (as required by the C++ standard) but
doesn't perform general optimizations on its ASTs.

The primary downside of tree optimizers is that you can't see across operations
that have multiple uses. It is
[well known in the literature](https://llvm.org/pubs/2008-06-LCTES-ISelUsingSSAGraphs.pdf)
that DAG pattern matching is more powerful than tree pattern matching, but OTOH,
DAG pattern matching can lead to duplication of computation, which needs to be
checked for.

### "Combiners" and other peephole optimizers

Compilers end up with a lot of peephole optimizers for various things, e.g. the
GCC
["combine" routines](https://github.com/gcc-mirror/gcc/blob/master/gcc/combine.c)
(which try to merge two machine instructions into a single one), the LLVM
[Inst Combine](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/)
[pass](https://llvm.org/docs/Passes.html#instcombine-combine-redundant-instructions),
LLVM's
[DAG Combiner](https://github.com/llvm-mirror/llvm/blob/master/lib/CodeGen/SelectionDAG/DAGCombiner.cpp),
the Swift compiler's
[SIL Combiner](https://github.com/apple/swift/tree/master/lib/SILOptimizer/SILCombiner),
etc. These generally match one or more operations and produce zero or more
operations as a result. The LLVM
[Legalization](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/)
infrastructure has a different outer loop but otherwise works the same way.

These passes have a lot of diversity, but also share a unifying structure: they
mostly have a worklist outer loop which visits operations. They then use the C++
visitor pattern (or equivalent) to switch over the class of operation and
dispatch to a method. That method contains a long sequence of hand-written C++
code that pattern-matches various special cases. LLVM introduced a "match"
function that allows writing patterns in a somewhat more declarative style using
template metaprogramming (MLIR has similar facilities).
Here's a simple example:

```c++
  // Y - (X + 1) --> ~X + Y
  if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
    return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
```

Here is a somewhat more complicated one (this is not the biggest or most
complicated :)

```c++
  // C2 is ODD
  // LHS = XOR(Y,C1), Y = AND(Z,C2), C1==(C2+1) => LHS == NEG(OR(Z, ~C2))
  // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
  if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
    if (C1->countTrailingZeros() == 0)
      if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
        Value *NewOr = Builder.CreateOr(Z, ~(*C2));
        return Builder.CreateSub(RHS, NewOr, "sub");
      }
```

These systems are simple to set up, and pattern matching templates have some
advantages (they are extensible for new sorts of sub-patterns and look compact
at the point of use). OTOH, they have lots of well-known problems, for example:

*   These patterns are very error-prone to write and contain lots of
    redundancies.
*   The IR being matched often has identities (e.g. when matching commutative
    operators) and the C++ code has to handle them manually - take a look at
    [the full code](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp?view=markup#l775)
    for `checkForNegativeOperand`, which defines the second pattern above.
*   The matching code compiles slowly, both because it generates tons of code
    and because the templates instantiate slowly.
*   Adding new patterns (e.g. for count leading zeros in the example above) is
    awkward and doesn't often happen.
*   The cost model for these patterns is not really defined - it is emergent
    based on the order in which the patterns are matched in the code.
*   They are non-extensible without rebuilding the compiler.
*   It isn't practical to apply theorem provers and other tools to these
    patterns - they cannot be reused for other purposes.

In addition to structured "combiners" like these, there are lots of related
ad-hoc systems like the
[LLVM Machine code peephole optimizer](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/PeepholeOptimizer.cpp?view=markup).

### LLVM's DAG-to-DAG Instruction Selection Infrastructure

The instruction selection subsystem in LLVM is the result of many years' worth
of iteration and discovery, driven by the need for LLVM to support code
generation for lots of targets, the complexity of code generators for modern
instruction sets (e.g. X86), and the fanatical pursuit of reusing code across
targets. Eli Bendersky wrote a
[nice short overview](https://eli.thegreenplace.net/2013/02/25/a-deeper-look-into-the-llvm-code-generator-part-1)
of how this works, and the
[LLVM documentation](https://llvm.org/docs/CodeGenerator.html#select-instructions-from-dag)
describes it in more depth, including its advantages and limitations. It allows
writing patterns like this:

```
def : Pat<(or GR64:$src, (not (add GR64:$src, 1))),
          (BLCI64rr GR64:$src)>;
```

This example defines a matcher for the
["blci" instruction](https://en.wikipedia.org/wiki/Bit_Manipulation_Instruction_Sets#TBM_\(Trailing_Bit_Manipulation\))
in the
[X86 target description](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrInfo.td?view=markup);
there are many others in that file (look for `Pat<>` patterns, since they aren't
entangled in details of the compiler like assembler/disassembler generation
logic).
For our purposes, there is much to like about this system, for example:

*   It is defined in a declarative format.
*   It is extensible to target-defined operations.
*   It automates matching across identities, like commutative patterns.
*   It allows custom abstractions and intense factoring of target-specific
    commonalities.
*   It generates compact code - it compiles into a state machine, which is
    interpreted.
*   It allows the instruction patterns to be defined and reused for multiple
    purposes.
*   The patterns are "type checked" at compile time, detecting lots of bugs
    early and eliminating redundancy from the pattern specifications.
*   It allows the use of general C++ code for weird/complex cases.

While there is a lot that is good here, there are also significant downsides:

*   All of this machinery is only applicable to instruction selection. Even
    directly adjacent problems like the DAGCombiner and Legalizer can't use it.
*   This isn't extensible at compiler runtime; you have to rebuild the compiler
    to extend it.
*   The error messages when failing to match a pattern
    [are not exactly optimal](https://www.google.com/search?q=llvm+cannot+select).
*   It has lots of implementation problems and limitations (e.g. you can't
    write a pattern for a multi-result operation) as a result of working with
    the awkward SelectionDAG representation and being designed and implemented
    lazily.
*   This stuff all grew organically over time and has lots of sharp edges.

### Summary

MLIR will face a wide range of pattern matching and graph rewrite problems, and
one of the major advantages of having a common representation for code at
multiple levels is that it allows us to invest in - and highly leverage - a
single infrastructure for doing this sort of work.

## Goals

This proposal includes support for defining pattern matching and rewrite
algorithms on MLIR. We'd like these algorithms to encompass many problems in the
MLIR space, including 1-to-N expansions (e.g. as seen in the TF/XLA bridge when
lowering a "tf.AddN" to multiple "add" HLOs), M-to-1 patterns (as seen in
Grappler optimization passes, e.g. those that convert multiply/add into a single
muladd op), as well as general M-to-N patterns (e.g. instruction selection for
target instructions). Patterns should have a cost associated with them, and the
common infrastructure should be responsible for sorting out the lowest-cost
match for a given application.

We separate the task of picking a particular locally optimal pattern for a
given root node, the algorithm used to rewrite an entire graph given a
particular set of goals, and the definition of the patterns themselves. We do
this because DAG tile pattern matching is NP-complete, which means that there
are no known polynomial-time algorithms to optimally solve this problem.
Additionally, we would like to support iterative rewrite algorithms that
progressively transform the input program through multiple steps. Furthermore,
we would like to support many different sorts of clients across the MLIR stack,
and they may have different tolerances for compile-time cost, different demands
for optimality, and other algorithmic goals or constraints.

We aim for MLIR transformations to be easy to implement and to reduce the
likelihood of compiler bugs.
We expect there to be a very large number of patterns defined over time, and we
believe that these sorts of patterns will have a very large number of
legality/validity constraints - many of which are difficult to reason about in a
consistent way, may be target-specific, and whose implementation may be
particularly bug-prone. As such, we aim to design the API around pattern
definition to be simple, resilient to programmer errors, and to allow a
separation of concerns between the legality of the generated nodes and the
pattern being defined.

Finally, error handling is a top concern: in addition to allowing patterns to be
defined in a target-independent way that may not apply to all hardware, we also
want the failure of any pattern to match to be diagnosable in a reasonable way.
To be clear, this is not a solvable problem in general - the space of
malfunction is too great to be fully enumerated and handled optimally, but there
are better and worse ways to handle the situation. MLIR is already designed to
represent the provenance of an operation well. This project aims to propagate
that provenance information precisely, as well as to diagnose pattern match
failures with the rationale for why a set of patterns does not apply.

### Non goals

This proposal doesn't aim to solve all compiler problems; it is simply a
DAG-to-DAG pattern matching system, starting with a greedy driver algorithm.
Compiler algorithms that require global dataflow analysis (e.g. common
subexpression elimination, conditional constant propagation, and many others)
will not be directly solved by this infrastructure.

This proposal is limited to DAG patterns, which (by definition) prevent the
patterns from seeing across cycles in a graph. In an SSA-based IR like MLIR,
this means that these patterns don't see across PHI nodes / basic block
arguments. We consider this acceptable given the set of problems we are trying
to solve - we don't know of any other system that attempts to do so, and
consider the payoff of worrying about this to be low.

This design includes the ability for DAG patterns to have associated costs
(benefits), but those costs are defined in terms of magic numbers (typically
equal to the number of nodes being replaced). For any given application, the
units of these magic numbers will have to be defined.

## Overall design

We decompose the problem into four major pieces:

1.  the code used to define the patterns to match, their cost, and their
    replacement actions
1.  the driver logic to pick the best match for a given root node
1.  the client that is implementing some transformation (e.g. a combiner)
1.  (future) the subsystem that allows patterns to be described with a
    declarative syntax, which sugars step #1.

We sketch each of the first three pieces in turn. This is not intended to be a
concrete API proposal, merely to describe the design.

### Defining Patterns

Each pattern will be an instance of an `mlir::Pattern` class, whose subclasses
implement methods like this. Note that this API is meant for exposition; the
actual details are different for efficiency and coding-standard reasons (e.g.
the memory management of `PatternState` is not specified below, etc.):

```c++
class Pattern {
  /// Return the benefit (the inverse of "cost") of matching this pattern. The
  /// benefit of a Pattern is always static - rewrites that may have dynamic
  /// benefit can be instantiated multiple times (different Pattern instances)
  /// for each benefit that they may return, and be guarded by different match
  /// condition predicates.
  PatternBenefit getBenefit() const { return benefit; }

  /// Return the root node that this pattern matches. Patterns that can
  /// match multiple root types are instantiated once per root.
  OperationName getRootKind() const { return rootKind; }

  /// Attempt to match against code rooted at the specified operation,
  /// which is the same operation code as getRootKind(). On failure, this
  /// returns a None value. On success, it returns a (possibly null)
  /// pattern-specific state wrapped in a Some. This state is passed back
  /// into its rewrite function if this match is selected.
  virtual Optional<PatternState*> match(Operation *op) const = 0;

  /// Rewrite the IR rooted at the specified operation with the result of
  /// this pattern, generating any new operations with the specified
  /// rewriter. If an unexpected error is encountered (an internal
  /// compiler error), it is emitted through the normal MLIR diagnostic
  /// hooks and the IR is left in a valid state.
  virtual void rewrite(Operation *op, PatternState *state,
                       PatternRewriter &rewriter) const;
};
```
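For illustration only, here is a rough sketch of what a concrete pattern might
look like when written against this exposition-only API. The op name "addi",
the `Pattern` constructor arguments, and helpers such as `matchPattern`/`m_Zero`
are assumptions made for this example rather than part of the proposal:

```c++
/// Hypothetical pattern that rewrites "x + 0" into "x" for an integer add
/// operation named "addi". This is a sketch against the exposition-only API
/// above; the constructor signature and the matcher helpers are assumed.
struct SimplifyAddZero : public Pattern {
  SimplifyAddZero(MLIRContext *context)
      : Pattern(/*rootKind=*/"addi", /*benefit=*/1, context) {}

  Optional<PatternState*> match(Operation *op) const override {
    // The legality constraint lives here: only fire when the right-hand
    // operand is a constant zero.
    if (!matchPattern(op->getOperand(1), m_Zero()))
      return llvm::None;                        // Not a match.
    return Optional<PatternState*>(nullptr);    // Match, no extra state needed.
  }

  void rewrite(Operation *op, PatternState *state,
               PatternRewriter &rewriter) const override {
    // "x + 0" is just "x": replace the operation with its first operand.
    rewriter.replaceOp(op, op->getOperand(0));
  }
};
```

Keeping the match condition entirely inside `match` leaves `rewrite` trivially
simple once a match has been selected, which is the separation of concerns the
Goals section argues for.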
In practice, the first patterns we implement will directly subclass and
implement this stuff, but we will define some helpers to reduce boilerplate.
When we have a declarative way to describe patterns, this should be
automatically generated from the description.

Instances of `Pattern` have a benefit that is static upon construction of the
pattern instance, but may be computed dynamically at pattern initialization time
(e.g. allowing the benefit to be derived from domain-specific information, like
the target architecture). This limitation allows MLIR to (eventually) perform
pattern fusion and compile patterns into an efficient state machine, and
[Thier, Ertl, and Krall](https://dl.acm.org/citation.cfm?id=3179501) have shown
that match predicates eliminate the need for dynamically computed costs in
almost all cases: you can simply instantiate the same pattern one time for each
possible cost and use the predicate to guard the match.

The two-phase nature of this API (match separate from rewrite) is important for
two reasons: 1) some clients may want to explore different ways to tile the
graph, and only rewrite after committing to one tiling; 2) we want to support
runtime extensibility of the pattern sets, but want to be able to statically
compile the bulk of known patterns into a state machine at "compiler compile
time". Both of these reasons lead to us needing to match multiple patterns
before committing to an answer.

### Picking and performing a replacement

In the short term, this API can be very simple; something like this can work and
will be useful for many clients:

```c++
class PatternMatcher {
  // Create a pattern matcher with a bunch of patterns. This constructor
  // looks across all of the specified patterns, and builds an internal
  // data structure that allows efficient matching.
  PatternMatcher(ArrayRef<Pattern*> patterns);

  // Given a specific operation, see if there is some rewrite that is
  // interesting. If so, this returns success and fills in the list of new
  // operations that were created. If not, it returns failure.
  bool matchAndRewrite(Operation *op,
                       SmallVectorImpl<Operation*> &newlyCreatedOps);
};
```

In practice, the interesting part of this class is the acceleration structure it
builds internally. It buckets up the patterns by root operation and sorts them
by their static benefit. When performing a match, it tests any dynamic patterns,
then tests statically known patterns from highest to lowest benefit.

### First Client: A Greedy Worklist Combiner

We expect that there will be lots of clients for this, but a simple greedy
worklist-driven combiner should be powerful enough to serve many important ones,
including the
[TF2XLA op expansion logic](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla/kernels),
many of the pattern substitution passes of the
[TOCO compiler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco)
for TF-Lite, many
[Grappler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/grappler)
passes, and other general performance optimizations for applying identities.

The structure of this algorithm is straightforward; here is pseudocode (a rough
C++ sketch of this loop appears at the end of this document):

*   Walk a function in preorder, adding each operation to a worklist.
*   While the worklist is non-empty, pull something off the back (processing
    things generally in postorder):
    *   Perform `matchAndRewrite` on the operation. If it fails, continue to
        the next operation.
    *   On success, add the newly created ops to the worklist and continue.

## Future directions

It is important to get implementation and usage experience with this, and many
patterns can be defined using this sort of framework. Over time, we can look to
make it easier to define patterns in a declarative form (e.g. with the LLVM
tblgen tool or something newer/better). Once we have that, we can define an
internal abstraction for describing the patterns to match, allowing better
high-level optimization of patterns (including fusion of the matching logic
across patterns, which the LLVM instruction selector does) and allowing the
patterns to be defined without rebuilding the compiler itself.
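To complement the pseudocode above, here is a rough C++ sketch of the greedy
worklist driver, written against the exposition-only `Pattern` and
`PatternMatcher` APIs in this document. The function name
`applyPatternsGreedily`, the `walk` traversal helper, and the worklist
bookkeeping are illustrative assumptions rather than a concrete implementation:

```c++
/// Illustrative sketch of the greedy worklist combiner described above. The
/// names and traversal helper used here are assumed for exposition; a real
/// driver would also need to drop erased operations from the worklist.
void applyPatternsGreedily(Function &function, ArrayRef<Pattern*> patterns) {
  PatternMatcher matcher(patterns);

  // Walk the function in preorder, adding each operation to the worklist.
  std::vector<Operation*> worklist;
  function.walk([&](Operation *op) { worklist.push_back(op); });

  // Pull operations off the back of the worklist, processing things generally
  // in postorder.
  while (!worklist.empty()) {
    Operation *op = worklist.back();
    worklist.pop_back();

    // Try to rewrite this operation. If no pattern matches, move on.
    SmallVector<Operation*, 4> newlyCreatedOps;
    if (!matcher.matchAndRewrite(op, newlyCreatedOps))
      continue;

    // Newly created operations may themselves be matchable, so add them to
    // the worklist and keep going.
    worklist.insert(worklist.end(), newlyCreatedOps.begin(),
                    newlyCreatedOps.end());
  }
}
```

Because newly created operations are pushed back onto the worklist, this sketch
keeps applying patterns until it reaches a fixed point, matching the behavior
described in the combiner section above.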