diff options
Diffstat (limited to 'mlir/lib/IR/PatternMatch.cpp')
| -rw-r--r-- | mlir/lib/IR/PatternMatch.cpp | 74 |
1 files changed, 25 insertions, 49 deletions
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index bc746351352..7c86a1e9995 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -31,18 +31,6 @@ unsigned short PatternBenefit::getBenefit() const { return representation; } -bool PatternBenefit::operator==(const PatternBenefit& other) { - if (isImpossibleToMatch()) - return other.isImpossibleToMatch(); - if (other.isImpossibleToMatch()) - return false; - return getBenefit() == other.getBenefit(); -} - -bool PatternBenefit::operator!=(const PatternBenefit& other) { - return !(*this == other); -} - //===----------------------------------------------------------------------===// // Pattern implementation //===----------------------------------------------------------------------===// @@ -65,7 +53,12 @@ void RewritePattern::rewrite(Instruction *op, } void RewritePattern::rewrite(Instruction *op, PatternRewriter &rewriter) const { - llvm_unreachable("need to implement one of the rewrite functions!"); + llvm_unreachable("need to implement either matchAndRewrite or one of the " + "rewrite functions!"); +} + +PatternMatchResult RewritePattern::match(Instruction *op) const { + llvm_unreachable("need to implement either match or matchAndRewrite!"); } PatternRewriter::~PatternRewriter() { @@ -131,45 +124,28 @@ void PatternRewriter::updatedRootInPlace( // PatternMatcher implementation //===----------------------------------------------------------------------===// -/// Find the highest benefit pattern available in the pattern set for the DAG -/// rooted at the specified node. This returns the pattern if found, or null -/// if there are no matches. -auto PatternMatcher::findMatch(Instruction *op) -> MatchResult { - // TODO: This is a completely trivial implementation, expand this in the - // future. - - // Keep track of the best match, the benefit of it, and any matcher specific - // state it is maintaining. - MatchResult bestMatch = {nullptr, nullptr}; - Optional<PatternBenefit> bestBenefit; +RewritePatternMatcher::RewritePatternMatcher( + OwningRewritePatternList &&patterns, PatternRewriter &rewriter) + : patterns(std::move(patterns)), rewriter(rewriter) { + // Sort the patterns by benefit to simplify the matching logic. + std::stable_sort(this->patterns.begin(), this->patterns.end(), + [](const std::unique_ptr<RewritePattern> &l, + const std::unique_ptr<RewritePattern> &r) { + return r->getBenefit() < l->getBenefit(); + }); +} +/// Try to match the given operation to a pattern and rewrite it. +void RewritePatternMatcher::matchAndRewrite(Instruction *op) { for (auto &pattern : patterns) { - // Ignore patterns that are for the wrong root. - if (pattern->getRootKind() != op->getName()) + // Ignore patterns that are for the wrong root or are impossible to match. + if (pattern->getRootKind() != op->getName() || + pattern->getBenefit().isImpossibleToMatch()) continue; - auto benefit = pattern->getBenefit(); - if (benefit.isImpossibleToMatch()) - continue; - - // If the benefit of the pattern is worse than what we've already found then - // don't run it. - if (bestBenefit.hasValue() && - benefit.getBenefit() < bestBenefit.getValue().getBenefit()) - continue; - - // Check to see if this pattern matches this node. - auto result = pattern->match(op); - - // If this pattern failed to match, ignore it. - if (!result) - continue; - - // Okay we found a match that is better than our previous one, remember it. - bestBenefit = benefit; - bestMatch = {pattern.get(), std::move(result.getValue())}; + // Try to match and rewrite this pattern. The patterns are sorted by + // benefit, so if we match we can immediately rewrite and return. + if (pattern->matchAndRewrite(op, rewriter)) + return; } - - // If we found any match, return it. - return bestMatch; } |

