diff options
| -rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 98 | ||||
| -rw-r--r-- | mlir/lib/IR/PatternMatch.cpp | 74 | ||||
| -rw-r--r-- | mlir/lib/StandardOps/Ops.cpp | 76 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 18 |
4 files changed, 120 insertions, 146 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 5d5b1c3a1ff..a65c8dd7c8a 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -44,17 +44,19 @@ public: PatternBenefit &operator=(const PatternBenefit &) = default; static PatternBenefit impossibleToMatch() { return PatternBenefit(); } - - bool isImpossibleToMatch() const { - return representation == ImpossibleToMatchSentinel; - } + bool isImpossibleToMatch() const { return *this == impossibleToMatch(); } /// If the corresponding pattern can match, return its benefit. If the // corresponding pattern isImpossibleToMatch() then this aborts. unsigned short getBenefit() const; - inline bool operator==(const PatternBenefit& other); - inline bool operator!=(const PatternBenefit& other); + bool operator==(const PatternBenefit &rhs) const { + return representation == rhs.representation; + } + bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); } + bool operator<(const PatternBenefit &rhs) const { + return representation < rhs.representation; + } private: PatternBenefit() : representation(ImpossibleToMatchSentinel) {} @@ -105,9 +107,8 @@ public: /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). On failure, this - /// returns a None value. On success it a (possibly null) pattern-specific - /// state wrapped in a Some. This state is passed back into its rewrite - /// function if this match is selected. + /// returns a None value. On success it returns a (possibly null) + /// pattern-specific state wrapped in an Optional. virtual PatternMatchResult match(Instruction *op) const = 0; virtual ~Pattern() {} @@ -138,8 +139,14 @@ private: }; /// RewritePattern is the common base class for all DAG to DAG replacements. -/// After a RewritePattern is matched, its replacement is performed by invoking -/// the "rewrite" method that the instance implements. +/// There are two possible usages of this class: +/// * Multi-step RewritePattern with "match" and "rewrite" +/// - By overloading the "match" and "rewrite" functions, the user can +/// separate the concerns of matching and rewriting. +/// * Single-step RewritePattern with "matchAndRewrite" +/// - By overloading the "matchAndRewrite" function, the user can perform +/// the rewrite in the same call as the match. This removes the need for +/// any PatternState. /// class RewritePattern : public Pattern { public: @@ -158,6 +165,25 @@ public: /// hooks and the IR is left in a valid state. virtual void rewrite(Instruction *op, PatternRewriter &rewriter) const; + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). On failure, this + /// returns a None value. On success, it returns a (possibly null) + /// pattern-specific state wrapped in an Optional. This state is passed back + /// into the rewrite function if this match is selected. + PatternMatchResult match(Instruction *op) const override; + + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). If successful, this + /// function will automatically perform the rewrite. + virtual PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const { + if (auto matchResult = match(op)) { + rewrite(op, std::move(*matchResult), rewriter); + return matchSuccess(); + } + return matchFailure(); + } + protected: /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. @@ -289,51 +315,37 @@ private: }; //===----------------------------------------------------------------------===// -// PatternMatcher class +// Pattern-driven rewriters //===----------------------------------------------------------------------===// /// This is a vector that owns the patterns inside of it. -using OwningPatternList = std::vector<std::unique_ptr<Pattern>>; +using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; -/// This class manages optimization and execution of a group of patterns, -/// providing an API for finding the best match against a given node. +/// This class manages optimization and execution of a group of rewrite +/// patterns, providing an API for finding and applying, the best match against +/// a given node. /// -class PatternMatcher { +class RewritePatternMatcher { public: - /// Create a PatternMatch with the specified set of patterns. - explicit PatternMatcher(OwningPatternList &&patterns) - : patterns(std::move(patterns)) {} - - // Support matching from subclasses of Pattern. - template <typename T> - explicit PatternMatcher(std::vector<std::unique_ptr<T>> &&patternSubclasses) { - patterns.reserve(patternSubclasses.size()); - for (auto &&elt : patternSubclasses) - patterns.emplace_back(std::move(elt)); - } - - using MatchResult = std::pair<Pattern *, std::unique_ptr<PatternState>>; + /// Create a RewritePatternMatcher with the specified set of patterns and + /// rewriter. + explicit RewritePatternMatcher(OwningRewritePatternList &&patterns, + PatternRewriter &rewriter); - /// Find the highest benefit pattern available in the pattern set for the DAG - /// rooted at the specified node. This returns the pattern (and any state it - /// needs) if found, or null if there are no matches. - MatchResult findMatch(Instruction *op); + /// Try to match the given operation to a pattern and rewrite it. + void matchAndRewrite(Instruction *op); private: - PatternMatcher(const PatternMatcher &) = delete; - void operator=(const PatternMatcher &) = delete; + RewritePatternMatcher(const RewritePatternMatcher &) = delete; + void operator=(const RewritePatternMatcher &) = delete; /// The group of patterns that are matched for optimization through this /// matcher. - OwningPatternList patterns; -}; + OwningRewritePatternList patterns; -//===----------------------------------------------------------------------===// -// Pattern-driven rewriters -//===----------------------------------------------------------------------===// - -/// This is a vector that owns the patterns inside of it. -using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>; + /// The rewriter used when applying matched patterns. + PatternRewriter &rewriter; +}; /// Rewrite the specified function by repeatedly applying the highest benefit /// patterns in a greedy work-list driven manner. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index bc746351352..7c86a1e9995 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -31,18 +31,6 @@ unsigned short PatternBenefit::getBenefit() const { return representation; } -bool PatternBenefit::operator==(const PatternBenefit& other) { - if (isImpossibleToMatch()) - return other.isImpossibleToMatch(); - if (other.isImpossibleToMatch()) - return false; - return getBenefit() == other.getBenefit(); -} - -bool PatternBenefit::operator!=(const PatternBenefit& other) { - return !(*this == other); -} - //===----------------------------------------------------------------------===// // Pattern implementation //===----------------------------------------------------------------------===// @@ -65,7 +53,12 @@ void RewritePattern::rewrite(Instruction *op, } void RewritePattern::rewrite(Instruction *op, PatternRewriter &rewriter) const { - llvm_unreachable("need to implement one of the rewrite functions!"); + llvm_unreachable("need to implement either matchAndRewrite or one of the " + "rewrite functions!"); +} + +PatternMatchResult RewritePattern::match(Instruction *op) const { + llvm_unreachable("need to implement either match or matchAndRewrite!"); } PatternRewriter::~PatternRewriter() { @@ -131,45 +124,28 @@ void PatternRewriter::updatedRootInPlace( // PatternMatcher implementation //===----------------------------------------------------------------------===// -/// Find the highest benefit pattern available in the pattern set for the DAG -/// rooted at the specified node. This returns the pattern if found, or null -/// if there are no matches. -auto PatternMatcher::findMatch(Instruction *op) -> MatchResult { - // TODO: This is a completely trivial implementation, expand this in the - // future. - - // Keep track of the best match, the benefit of it, and any matcher specific - // state it is maintaining. - MatchResult bestMatch = {nullptr, nullptr}; - Optional<PatternBenefit> bestBenefit; +RewritePatternMatcher::RewritePatternMatcher( + OwningRewritePatternList &&patterns, PatternRewriter &rewriter) + : patterns(std::move(patterns)), rewriter(rewriter) { + // Sort the patterns by benefit to simplify the matching logic. + std::stable_sort(this->patterns.begin(), this->patterns.end(), + [](const std::unique_ptr<RewritePattern> &l, + const std::unique_ptr<RewritePattern> &r) { + return r->getBenefit() < l->getBenefit(); + }); +} +/// Try to match the given operation to a pattern and rewrite it. +void RewritePatternMatcher::matchAndRewrite(Instruction *op) { for (auto &pattern : patterns) { - // Ignore patterns that are for the wrong root. - if (pattern->getRootKind() != op->getName()) + // Ignore patterns that are for the wrong root or are impossible to match. + if (pattern->getRootKind() != op->getName() || + pattern->getBenefit().isImpossibleToMatch()) continue; - auto benefit = pattern->getBenefit(); - if (benefit.isImpossibleToMatch()) - continue; - - // If the benefit of the pattern is worse than what we've already found then - // don't run it. - if (bestBenefit.hasValue() && - benefit.getBenefit() < bestBenefit.getValue().getBenefit()) - continue; - - // Check to see if this pattern matches this node. - auto result = pattern->match(op); - - // If this pattern failed to match, ignore it. - if (!result) - continue; - - // Okay we found a match that is better than our previous one, remember it. - bestBenefit = benefit; - bestMatch = {pattern.get(), std::move(result.getValue())}; + // Try to match and rewrite this pattern. The patterns are sorted by + // benefit, so if we match we can immediately rewrite and return. + if (pattern->matchAndRewrite(op, rewriter)) + return; } - - // If we found any match, return it. - return bestMatch; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index a14f3a24e82..50db72faea1 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -356,15 +356,16 @@ struct SimplifyDeadAlloc : public RewritePattern { SimplifyDeadAlloc(MLIRContext *context) : RewritePattern(AllocOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { + // Check if the alloc'ed value has any uses. auto alloc = op->cast<AllocOp>(); - // Check if the alloc'ed value has no uses. - return alloc->use_empty() ? matchSuccess() : matchFailure(); - } + if (!alloc->use_empty()) + return matchFailure(); - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - // Erase the alloc operation. + // If it doesn't, we can eliminate it. op->erase(); + return matchSuccess(); } }; } // end anonymous namespace. @@ -486,29 +487,24 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { SimplifyIndirectCallWithKnownCallee(MLIRContext *context) : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto indirectCall = op->cast<CallIndirectOp>(); // Check that the callee is a constant operation. - Value *callee = indirectCall->getCallee(); - Instruction *calleeInst = callee->getDefiningInst(); - if (!calleeInst || !calleeInst->isa<ConstantOp>()) + Attribute callee; + if (!matchPattern(indirectCall->getCallee(), m_Constant(&callee))) return matchFailure(); // Check that the constant callee is a function. - if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>()) - return matchSuccess(); - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto indirectCall = op->cast<CallIndirectOp>(); - auto calleeOp = - indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>(); + FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>(); + if (!calledFn) + return matchFailure(); // Replace with a direct call. - Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue(); SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands()); - rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands); + rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands); + return matchSuccess(); } }; } // end anonymous namespace. @@ -802,15 +798,14 @@ struct SimplifyConstCondBranchPred : public RewritePattern { SimplifyConstCondBranchPred(MLIRContext *context) : RewritePattern(CondBranchOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto condbr = op->cast<CondBranchOp>(); - if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>())) - return matchSuccess(); - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { - auto condbr = op->cast<CondBranchOp>(); + // Check that the condition is a constant. + if (!matchPattern(condbr->getCondition(), m_Op<ConstantOp>())) + return matchFailure(); + Block *foldedDest; SmallVector<Value *, 4> branchArgs; @@ -828,6 +823,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern { } rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs); + return matchSuccess(); } }; } // end anonymous namespace. @@ -1094,7 +1090,8 @@ struct SimplifyDeadDealloc : public RewritePattern { SimplifyDeadDealloc(MLIRContext *context) : RewritePattern(DeallocOp::getOperationName(), 1, context) {} - PatternMatchResult match(Instruction *op) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto dealloc = op->cast<DeallocOp>(); // Check that the memref operand's defining instruction is an AllocOp. @@ -1107,12 +1104,10 @@ struct SimplifyDeadDealloc : public RewritePattern { for (auto &use : memref->getUses()) if (!use.getOwner()->isa<DeallocOp>()) return matchFailure(); - return matchSuccess(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { // Erase the dealloc operation. op->erase(); + return matchSuccess(); } }; } // end anonymous namespace. @@ -1991,21 +1986,16 @@ namespace { /// struct SimplifyXMinusX : public RewritePattern { SimplifyXMinusX(MLIRContext *context) - : RewritePattern(SubIOp::getOperationName(), 1, context) {} + : RewritePattern(SubIOp::getOperationName(), 10, context) {} - PatternMatchResult match(Instruction *op) const override { - auto subi = op->cast<SubIOp>(); - if (subi->getOperand(0) == subi->getOperand(1)) - return matchSuccess(); - - return matchFailure(); - } - void rewrite(Instruction *op, PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Instruction *op, + PatternRewriter &rewriter) const override { auto subi = op->cast<SubIOp>(); - auto result = - rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType()); + if (subi->getOperand(0) != subi->getOperand(1)) + return matchFailure(); - rewriter.replaceOp(op, {result}); + rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi->getType()); + return matchSuccess(); } }; } // end anonymous namespace. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 2a0238cbd96..a2d6f392c32 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -34,7 +34,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(Function *fn, OwningRewritePatternList &&patterns) - : PatternRewriter(fn->getContext()), matcher(std::move(patterns)), + : PatternRewriter(fn->getContext()), matcher(std::move(patterns), *this), builder(fn) { worklist.reserve(64); @@ -122,7 +122,7 @@ private: } /// The low-level pattern matcher. - PatternMatcher matcher; + RewritePatternMatcher matcher; /// This builder is used to create new operations. FuncBuilder builder; @@ -284,17 +284,13 @@ void GreedyPatternRewriteDriver::simplifyFunction() { continue; } - // Check to see if we have any patterns that match this node. - auto match = matcher.findMatch(op); - if (!match.first) - continue; - // Make sure that any new operations are inserted at this point. builder.setInsertionPoint(op); - // We know that any pattern that matched is RewritePattern because we - // initialized the matcher with RewritePatterns. - auto *rewritePattern = static_cast<RewritePattern *>(match.first); - rewritePattern->rewrite(op, std::move(match.second), *this); + + // Try to match one of the canonicalization patterns. The rewriter is + // automatically notified of any necessary changes, so there is nothing else + // to do here. + matcher.matchAndRewrite(op); } uniquedConstants.clear(); |

