summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/PatternMatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/PatternMatch.cpp')
-rw-r--r--mlir/lib/IR/PatternMatch.cpp74
1 files changed, 25 insertions, 49 deletions
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index bc746351352..7c86a1e9995 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -31,18 +31,6 @@ unsigned short PatternBenefit::getBenefit() const {
return representation;
}
-bool PatternBenefit::operator==(const PatternBenefit& other) {
- if (isImpossibleToMatch())
- return other.isImpossibleToMatch();
- if (other.isImpossibleToMatch())
- return false;
- return getBenefit() == other.getBenefit();
-}
-
-bool PatternBenefit::operator!=(const PatternBenefit& other) {
- return !(*this == other);
-}
-
//===----------------------------------------------------------------------===//
// Pattern implementation
//===----------------------------------------------------------------------===//
@@ -65,7 +53,12 @@ void RewritePattern::rewrite(Instruction *op,
}
void RewritePattern::rewrite(Instruction *op, PatternRewriter &rewriter) const {
- llvm_unreachable("need to implement one of the rewrite functions!");
+ llvm_unreachable("need to implement either matchAndRewrite or one of the "
+ "rewrite functions!");
+}
+
+PatternMatchResult RewritePattern::match(Instruction *op) const {
+ llvm_unreachable("need to implement either match or matchAndRewrite!");
}
PatternRewriter::~PatternRewriter() {
@@ -131,45 +124,28 @@ void PatternRewriter::updatedRootInPlace(
// PatternMatcher implementation
//===----------------------------------------------------------------------===//
-/// Find the highest benefit pattern available in the pattern set for the DAG
-/// rooted at the specified node. This returns the pattern if found, or null
-/// if there are no matches.
-auto PatternMatcher::findMatch(Instruction *op) -> MatchResult {
- // TODO: This is a completely trivial implementation, expand this in the
- // future.
-
- // Keep track of the best match, the benefit of it, and any matcher specific
- // state it is maintaining.
- MatchResult bestMatch = {nullptr, nullptr};
- Optional<PatternBenefit> bestBenefit;
+RewritePatternMatcher::RewritePatternMatcher(
+ OwningRewritePatternList &&patterns, PatternRewriter &rewriter)
+ : patterns(std::move(patterns)), rewriter(rewriter) {
+ // Sort the patterns by benefit to simplify the matching logic.
+ std::stable_sort(this->patterns.begin(), this->patterns.end(),
+ [](const std::unique_ptr<RewritePattern> &l,
+ const std::unique_ptr<RewritePattern> &r) {
+ return r->getBenefit() < l->getBenefit();
+ });
+}
+/// Try to match the given operation to a pattern and rewrite it.
+void RewritePatternMatcher::matchAndRewrite(Instruction *op) {
for (auto &pattern : patterns) {
- // Ignore patterns that are for the wrong root.
- if (pattern->getRootKind() != op->getName())
+ // Ignore patterns that are for the wrong root or are impossible to match.
+ if (pattern->getRootKind() != op->getName() ||
+ pattern->getBenefit().isImpossibleToMatch())
continue;
- auto benefit = pattern->getBenefit();
- if (benefit.isImpossibleToMatch())
- continue;
-
- // If the benefit of the pattern is worse than what we've already found then
- // don't run it.
- if (bestBenefit.hasValue() &&
- benefit.getBenefit() < bestBenefit.getValue().getBenefit())
- continue;
-
- // Check to see if this pattern matches this node.
- auto result = pattern->match(op);
-
- // If this pattern failed to match, ignore it.
- if (!result)
- continue;
-
- // Okay we found a match that is better than our previous one, remember it.
- bestBenefit = benefit;
- bestMatch = {pattern.get(), std::move(result.getValue())};
+ // Try to match and rewrite this pattern. The patterns are sorted by
+ // benefit, so if we match we can immediately rewrite and return.
+ if (pattern->matchAndRewrite(op, rewriter))
+ return;
}
-
- // If we found any match, return it.
- return bestMatch;
}
OpenPOWER on IntegriCloud