1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
|
//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This transformation pass converts operations into their canonical forms by
// folding constants, applying operation identity transformations etc.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
#include <memory>
using namespace mlir;
//===----------------------------------------------------------------------===//
// Definition of Pattern and related types.
//===----------------------------------------------------------------------===//
// TODO(clattner): Move this out of this file when it is ready.
// TODO(clattner): Define this as a tagged union with proper sentinels.
/// Benefit of applying a pattern; a larger value is preferred by the matcher.
/// Negative values currently double as "did not match" sentinels.
using PatternBenefit = int;
/// Base class for data a pattern wants to carry from its match phase into its
/// rewrite phase. Patterns that need such data define their own subclass.
/// Instances are held through std::unique_ptr<PatternState>, which is why the
/// destructor is virtual.
class PatternState {
public:
  virtual ~PatternState() = default;
};
/// Value produced by Pattern::match. `first` is the benefit of the match
/// (negative means no match); `second` is optional pattern-specific state that
/// is handed back to the pattern's rewrite hook if this match is selected.
using PatternMatchResult =
    std::pair<PatternBenefit, std::unique_ptr<PatternState>>;
/// Base class for a rewrite pattern rooted at one kind of operation.
/// A pattern works in two phases: match(), which inspects an operation and
/// reports a benefit plus optional state, and rewrite(), which transforms IR.
class Pattern {
public:
  /// Returns the benefit (the inverse of "cost") of matching this pattern, as
  /// fixed when the pattern was constructed.
  /// TODO: distinct sentinels for "benefit is computed dynamically" and "can
  /// never match" are planned; today a negative value is the only sentinel.
  PatternBenefit getStaticBenefit() const { return staticBenefit; }

  /// Returns the kind of operation this pattern is rooted at. Patterns that
  /// can match several root kinds are instantiated once per root.
  OperationName getRootKind() const { return rootKind; }

  //===--------------------------------------------------------------------===//
  // Hooks that concrete patterns implement.
  //===--------------------------------------------------------------------===//

  /// Tries to match the IR rooted at `op`, whose operation kind equals
  /// getRootKind(). On success, returns the benefit of the match plus optional
  /// pattern-specific state that is passed back into rewrite() if this match
  /// is selected. On failure, returns matchFailure().
  virtual PatternMatchResult match(Operation *op) const = 0;

  /// Rewrites the IR rooted at `op` using the state produced by match(),
  /// creating any new operations through `builder`. Intended contract: an
  /// unexpected (internal compiler) error is reported through the normal MLIR
  /// diagnostic hooks and the IR is left valid. The default implementation
  /// drops `state` and forwards to the state-less overload.
  virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
                       // TODO: Need a generic builder.
                       MLFuncBuilder &builder) const {
    rewrite(op, builder);
  }

  /// State-less form of rewrite() for patterns that keep no match state.
  /// Every pattern must override one of the two rewrite() overloads; reaching
  /// this default is a programming error.
  virtual void rewrite(Operation *op,
                       // TODO: Need a generic builder.
                       MLFuncBuilder &builder) const {
    llvm_unreachable("need to implement one of the rewrite functions!");
  }

  virtual ~Pattern();

  //===--------------------------------------------------------------------===//
  // Helpers that simplify writing pattern implementations.
  //===--------------------------------------------------------------------===//

  /// Builds the result that signals "no match".
  static PatternMatchResult matchFailure() {
    // TODO: Use a proper sentinel / discriminated union instead of the -1
    // magic number.
    return {-1, nullptr};
  }

  /// Builds a successful match result with the given benefit and optional
  /// state to forward to rewrite().
  static PatternMatchResult
  matchSuccess(PatternBenefit benefit,
               std::unique_ptr<PatternState> state = nullptr) {
    return {benefit, std::move(state)};
  }

protected:
  Pattern(PatternBenefit staticBenefit, OperationName rootKind)
      : staticBenefit(staticBenefit), rootKind(rootKind) {}

private:
  /// Benefit reported by getStaticBenefit().
  const PatternBenefit staticBenefit;
  /// Operation kind reported by getRootKind().
  const OperationName rootKind;
};
// Out-of-line definition of the virtual destructor declared in Pattern.
Pattern::~Pattern() = default;
//===----------------------------------------------------------------------===//
// PatternMatcher class
//===----------------------------------------------------------------------===//
/// Owns a group of patterns and provides an API for finding the best one to
/// apply to a given operation.
///
class PatternMatcher {
public:
  /// Create a PatternMatcher with the specified set of patterns. This takes
  /// ownership of the patterns in question; they are deleted when the matcher
  /// is destroyed.
  explicit PatternMatcher(ArrayRef<Pattern *> patterns) {
    this->patterns.reserve(patterns.size());
    for (auto *pattern : patterns)
      this->patterns.emplace_back(pattern);
  }

  PatternMatcher(const PatternMatcher &) = delete;
  PatternMatcher &operator=(const PatternMatcher &) = delete;

  typedef std::pair<Pattern *, std::unique_ptr<PatternState>> MatchResult;

  /// Find the highest benefit pattern available in the pattern set for the DAG
  /// rooted at the specified node. This returns the pattern (and any state it
  /// needs) if found, or a null pattern if there are no matches. The returned
  /// pattern remains owned by this matcher.
  MatchResult findMatch(Operation *op);

private:
  /// Owned patterns, kept in the order they were provided.
  std::vector<std::unique_ptr<Pattern>> patterns;
};

/// Find the highest benefit pattern available in the pattern set for the DAG
/// rooted at the specified node. This returns the pattern if found, or null
/// if there are no matches. When several patterns report the same benefit,
/// the one that appears last in the pattern list wins.
auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
  // TODO: This is a completely trivial implementation, expand this in the
  // future.

  // Keep track of the best match, the benefit of it, and any matcher specific
  // state it is maintaining.
  MatchResult bestMatch = {nullptr, nullptr};
  // TODO: eliminate magic numbers.
  PatternBenefit bestBenefit = -1;

  for (const auto &pattern : patterns) {
    // Ignore patterns that are for the wrong root.
    if (pattern->getRootKind() != op->getName())
      continue;

    // If the static benefit is a "never matches" sentinel or is worse than
    // what we've already found, don't bother running the matcher.
    auto staticBenefit = pattern->getStaticBenefit();
    if (staticBenefit < 0 || staticBenefit < bestBenefit)
      continue;

    // Check to see if this pattern matches this node.
    auto result = pattern->match(op);
    // TODO: magic numbers.
    if (result.first < 0 || result.first < bestBenefit)
      continue;

    // This match is at least as good as the previous best one; remember it.
    bestBenefit = result.first;
    bestMatch = {pattern.get(), std::move(result.second)};
  }

  // If we found any match, return it.
  return bestMatch;
}
//===----------------------------------------------------------------------===//
// Definition of a few patterns for canonicalizing operations.
//===----------------------------------------------------------------------===//
namespace {
/// subi(x,x) -> 0
///
struct SimplifyXMinusX : public Pattern {
SimplifyXMinusX(MLIRContext *context)
// FIXME: rename getOperationName and add a proper one.
: Pattern(1, OperationName(SubIOp::getOperationName(), context)) {}
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
match(Operation *op) const override {
// TODO: Rename getAs -> dyn_cast, and add a cast<> method.
auto subi = op->getAs<SubIOp>();
assert(subi && "Matcher should have produced this");
if (subi->getOperand(0) == subi->getOperand(1))
return matchSuccess(1);
return matchFailure();
}
// Rewrite the IR rooted at the specified operation with the result of
// this pattern, generating any new operations with the specified
// builder. If an unexpected error is encountered (an internal
// compiler error), it is emitted through the normal MLIR diagnostic
// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, MLFuncBuilder &builder) const override {
// TODO: Rename getAs -> dyn_cast, and add a cast<> method.
auto subi = op->getAs<SubIOp>();
assert(subi && "Matcher should have produced this");
// TODO: Better "replace and remove" API on Pattern.
auto result =
builder.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
op->getResult(0)->replaceAllUsesWith(result->getResult());
cast<OperationStmt>(op)->eraseFromBlock();
}
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// The actual Canonicalizer Pass.
//===----------------------------------------------------------------------===//
// TODO: Canonicalize and unique all constant operations into the entry of the
// function.
namespace {
/// Function pass that rewrites operations into their canonical form.
/// Currently only ML functions are processed; the CFG entry point is a stub.
class Canonicalizer : public FunctionPass {
public:
  PassResult runOnMLFunction(MLFunction *f) override;
  PassResult runOnCFGFunction(CFGFunction *f) override;

  /// Drains `worklist`, applying the best matching pattern to each operation
  /// and creating any new operations through `builder`.
  void simplifyFunction(std::vector<Operation *> &worklist,
                        MLFuncBuilder &builder);
};
} // end anonymous namespace
/// CFG functions are not canonicalized yet: the function is left untouched and
/// the pass reports success.
PassResult Canonicalizer::runOnCFGFunction(CFGFunction *f) {
  (void)f; // TODO: Add this.
  return success();
}
/// Canonicalizes an ML function. Every operation statement in `f` is gathered
/// (in walk order) into a worklist, which simplifyFunction then drains using a
/// builder created for `f`.
PassResult Canonicalizer::runOnMLFunction(MLFunction *f) {
  std::vector<Operation *> pending;
  pending.reserve(64);

  auto collect = [&pending](OperationStmt *stmt) { pending.push_back(stmt); };
  f->walk(collect);

  MLFuncBuilder builder(f);
  simplifyFunction(pending, builder);
  return success();
}
// TODO: This should work on both ML and CFG functions.
/// Pops operations off `worklist` one at a time (last in, first out). For each
/// operation, finds the highest-benefit matching pattern and, if one exists,
/// applies its rewrite with `builder` positioned immediately before that
/// operation. Operations created by a rewrite are not added to the worklist.
void Canonicalizer::simplifyFunction(std::vector<Operation *> &worklist,
                                     MLFuncBuilder &builder) {
  // TODO: Instead of a hard coded list of patterns, ask the registered dialects
  // for their canonicalization patterns.
  PatternMatcher matcher({new SimplifyXMinusX(builder.getContext())});

  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();

    // TODO: If no side effects, and operation has no users, then it is
    // trivially dead - remove it.

    // TODO: Call the constant folding hook on this operation, and canonicalize
    // constants into the entry node.

    // Check to see if we have any patterns that match this node.
    auto match = matcher.findMatch(op);
    if (!match.first)
      continue;

    // Insert whatever the pattern creates right before the operation it is
    // rewriting, so replacement values dominate the op's former uses. The
    // incoming builder is constructed from the whole function and is not tied
    // to `op` (presumably it points at the end of the body - verify against
    // MLFuncBuilder). A rewrite may erase `op`, which leaves the insertion
    // point stale; that is harmless because it is re-established here before
    // every rewrite.
    builder.setInsertionPoint(cast<OperationStmt>(op));

    // TODO: Need to be a bit trickier to make sure new instructions get into
    // the worklist.
    match.first->rewrite(op, std::move(match.second), builder);
  }
}
/// Creates a new Canonicalizer pass. The object is heap-allocated with `new`
/// and returned as a raw pointer; the caller takes ownership.
FunctionPass *mlir::createCanonicalizerPass() {
  auto *pass = new Canonicalizer();
  return pass;
}
|