summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h98
-rw-r--r--mlir/lib/IR/PatternMatch.cpp74
-rw-r--r--mlir/lib/StandardOps/Ops.cpp76
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp18
4 files changed, 120 insertions, 146 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 5d5b1c3a1ff..a65c8dd7c8a 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -44,17 +44,19 @@ public:
PatternBenefit &operator=(const PatternBenefit &) = default;
static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
-
- bool isImpossibleToMatch() const {
- return representation == ImpossibleToMatchSentinel;
- }
+ bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
/// If the corresponding pattern can match, return its benefit. If the
// corresponding pattern isImpossibleToMatch() then this aborts.
unsigned short getBenefit() const;
- inline bool operator==(const PatternBenefit& other);
- inline bool operator!=(const PatternBenefit& other);
+ bool operator==(const PatternBenefit &rhs) const {
+ return representation == rhs.representation;
+ }
+ bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
+ bool operator<(const PatternBenefit &rhs) const {
+ return representation < rhs.representation;
+ }
private:
PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
@@ -105,9 +107,8 @@ public:
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). On failure, this
- /// returns a None value. On success it a (possibly null) pattern-specific
- /// state wrapped in a Some. This state is passed back into its rewrite
- /// function if this match is selected.
+ /// returns a None value. On success it returns a (possibly null)
+ /// pattern-specific state wrapped in an Optional.
virtual PatternMatchResult match(Instruction *op) const = 0;
virtual ~Pattern() {}
@@ -138,8 +139,14 @@ private:
};
/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// After a RewritePattern is matched, its replacement is performed by invoking
-/// the "rewrite" method that the instance implements.
+/// There are two possible usages of this class:
+/// * Multi-step RewritePattern with "match" and "rewrite"
+/// - By overloading the "match" and "rewrite" functions, the user can
+/// separate the concerns of matching and rewriting.
+/// * Single-step RewritePattern with "matchAndRewrite"
+/// - By overloading the "matchAndRewrite" function, the user can perform
+/// the rewrite in the same call as the match. This removes the need for
+/// any PatternState.
///
class RewritePattern : public Pattern {
public:
@@ -158,6 +165,25 @@ public:
/// hooks and the IR is left in a valid state.
virtual void rewrite(Instruction *op, PatternRewriter &rewriter) const;
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). On failure, this
+ /// returns a None value. On success, it returns a (possibly null)
+ /// pattern-specific state wrapped in an Optional. This state is passed back
+ /// into the rewrite function if this match is selected.
+ PatternMatchResult match(Instruction *op) const override;
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). If successful, this
+ /// function will automatically perform the rewrite.
+ virtual PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const {
+ if (auto matchResult = match(op)) {
+ rewrite(op, std::move(*matchResult), rewriter);
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
@@ -289,51 +315,37 @@ private:
};
//===----------------------------------------------------------------------===//
-// PatternMatcher class
+// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
/// This is a vector that owns the patterns inside of it.
-using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
+using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
-/// This class manages optimization and execution of a group of patterns,
-/// providing an API for finding the best match against a given node.
+/// This class manages optimization and execution of a group of rewrite
+/// patterns, providing an API for finding and applying, the best match against
+/// a given node.
///
-class PatternMatcher {
+class RewritePatternMatcher {
public:
- /// Create a PatternMatch with the specified set of patterns.
- explicit PatternMatcher(OwningPatternList &&patterns)
- : patterns(std::move(patterns)) {}
-
- // Support matching from subclasses of Pattern.
- template <typename T>
- explicit PatternMatcher(std::vector<std::unique_ptr<T>> &&patternSubclasses) {
- patterns.reserve(patternSubclasses.size());
- for (auto &&elt : patternSubclasses)
- patterns.emplace_back(std::move(elt));
- }
-
- using MatchResult = std::pair<Pattern *, std::unique_ptr<PatternState>>;
+ /// Create a RewritePatternMatcher with the specified set of patterns and
+ /// rewriter.
+ explicit RewritePatternMatcher(OwningRewritePatternList &&patterns,
+ PatternRewriter &rewriter);
- /// Find the highest benefit pattern available in the pattern set for the DAG
- /// rooted at the specified node. This returns the pattern (and any state it
- /// needs) if found, or null if there are no matches.
- MatchResult findMatch(Instruction *op);
+ /// Try to match the given operation to a pattern and rewrite it.
+ void matchAndRewrite(Instruction *op);
private:
- PatternMatcher(const PatternMatcher &) = delete;
- void operator=(const PatternMatcher &) = delete;
+ RewritePatternMatcher(const RewritePatternMatcher &) = delete;
+ void operator=(const RewritePatternMatcher &) = delete;
/// The group of patterns that are matched for optimization through this
/// matcher.
- OwningPatternList patterns;
-};
+ OwningRewritePatternList patterns;
-//===----------------------------------------------------------------------===//
-// Pattern-driven rewriters
-//===----------------------------------------------------------------------===//
-
-/// This is a vector that owns the patterns inside of it.
-using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+ /// The rewriter used when applying matched patterns.
+ PatternRewriter &rewriter;
+};
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index bc746351352..7c86a1e9995 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -31,18 +31,6 @@ unsigned short PatternBenefit::getBenefit() const {
return representation;
}
-bool PatternBenefit::operator==(const PatternBenefit& other) {
- if (isImpossibleToMatch())
- return other.isImpossibleToMatch();
- if (other.isImpossibleToMatch())
- return false;
- return getBenefit() == other.getBenefit();
-}
-
-bool PatternBenefit::operator!=(const PatternBenefit& other) {
- return !(*this == other);
-}
-
//===----------------------------------------------------------------------===//
// Pattern implementation
//===----------------------------------------------------------------------===//
@@ -65,7 +53,12 @@ void RewritePattern::rewrite(Instruction *op,
}
void RewritePattern::rewrite(Instruction *op, PatternRewriter &rewriter) const {
- llvm_unreachable("need to implement one of the rewrite functions!");
+ llvm_unreachable("need to implement either matchAndRewrite or one of the "
+ "rewrite functions!");
+}
+
+PatternMatchResult RewritePattern::match(Instruction *op) const {
+ llvm_unreachable("need to implement either match or matchAndRewrite!");
}
PatternRewriter::~PatternRewriter() {
@@ -131,45 +124,28 @@ void PatternRewriter::updatedRootInPlace(
// PatternMatcher implementation
//===----------------------------------------------------------------------===//
-/// Find the highest benefit pattern available in the pattern set for the DAG
-/// rooted at the specified node. This returns the pattern if found, or null
-/// if there are no matches.
-auto PatternMatcher::findMatch(Instruction *op) -> MatchResult {
- // TODO: This is a completely trivial implementation, expand this in the
- // future.
-
- // Keep track of the best match, the benefit of it, and any matcher specific
- // state it is maintaining.
- MatchResult bestMatch = {nullptr, nullptr};
- Optional<PatternBenefit> bestBenefit;
+RewritePatternMatcher::RewritePatternMatcher(
+ OwningRewritePatternList &&patterns, PatternRewriter &rewriter)
+ : patterns(std::move(patterns)), rewriter(rewriter) {
+ // Sort the patterns by benefit to simplify the matching logic.
+ std::stable_sort(this->patterns.begin(), this->patterns.end(),
+ [](const std::unique_ptr<RewritePattern> &l,
+ const std::unique_ptr<RewritePattern> &r) {
+ return r->getBenefit() < l->getBenefit();
+ });
+}
+/// Try to match the given operation to a pattern and rewrite it.
+void RewritePatternMatcher::matchAndRewrite(Instruction *op) {
for (auto &pattern : patterns) {
- // Ignore patterns that are for the wrong root.
- if (pattern->getRootKind() != op->getName())
+ // Ignore patterns that are for the wrong root or are impossible to match.
+ if (pattern->getRootKind() != op->getName() ||
+ pattern->getBenefit().isImpossibleToMatch())
continue;
- auto benefit = pattern->getBenefit();
- if (benefit.isImpossibleToMatch())
- continue;
-
- // If the benefit of the pattern is worse than what we've already found then
- // don't run it.
- if (bestBenefit.hasValue() &&
- benefit.getBenefit() < bestBenefit.getValue().getBenefit())
- continue;
-
- // Check to see if this pattern matches this node.
- auto result = pattern->match(op);
-
- // If this pattern failed to match, ignore it.
- if (!result)
- continue;
-
- // Okay we found a match that is better than our previous one, remember it.
- bestBenefit = benefit;
- bestMatch = {pattern.get(), std::move(result.getValue())};
+ // Try to match and rewrite this pattern. The patterns are sorted by
+ // benefit, so if we match we can immediately rewrite and return.
+ if (pattern->matchAndRewrite(op, rewriter))
+ return;
}
-
- // If we found any match, return it.
- return bestMatch;
}
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index a14f3a24e82..50db72faea1 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -356,15 +356,16 @@ struct SimplifyDeadAlloc : public RewritePattern {
SimplifyDeadAlloc(MLIRContext *context)
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
+ // Check if the alloc'ed value has any uses.
auto alloc = op->cast<AllocOp>();
- // Check if the alloc'ed value has no uses.
- return alloc->use_empty() ? matchSuccess() : matchFailure();
- }
+ if (!alloc->use_empty())
+ return matchFailure();
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
- // Erase the alloc operation.
+ // If it doesn't, we can eliminate it.
op->erase();
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -486,29 +487,24 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
: RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto indirectCall = op->cast<CallIndirectOp>();
// Check that the callee is a constant operation.
- Value *callee = indirectCall->getCallee();
- Instruction *calleeInst = callee->getDefiningInst();
- if (!calleeInst || !calleeInst->isa<ConstantOp>())
+ Attribute callee;
+ if (!matchPattern(indirectCall->getCallee(), m_Constant(&callee)))
return matchFailure();
// Check that the constant callee is a function.
- if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>())
- return matchSuccess();
- return matchFailure();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
- auto indirectCall = op->cast<CallIndirectOp>();
- auto calleeOp =
- indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>();
+ FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>();
+ if (!calledFn)
+ return matchFailure();
// Replace with a direct call.
- Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue();
SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands);
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -802,15 +798,14 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
SimplifyConstCondBranchPred(MLIRContext *context)
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto condbr = op->cast<CondBranchOp>();
- if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
- return matchSuccess();
- return matchFailure();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
- auto condbr = op->cast<CondBranchOp>();
+ // Check that the condition is a constant.
+ if (!matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
+ return matchFailure();
+
Block *foldedDest;
SmallVector<Value *, 4> branchArgs;
@@ -828,6 +823,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
}
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -1094,7 +1090,8 @@ struct SimplifyDeadDealloc : public RewritePattern {
SimplifyDeadDealloc(MLIRContext *context)
: RewritePattern(DeallocOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Instruction *op) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto dealloc = op->cast<DeallocOp>();
// Check that the memref operand's defining instruction is an AllocOp.
@@ -1107,12 +1104,10 @@ struct SimplifyDeadDealloc : public RewritePattern {
for (auto &use : memref->getUses())
if (!use.getOwner()->isa<DeallocOp>())
return matchFailure();
- return matchSuccess();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
// Erase the dealloc operation.
op->erase();
+ return matchSuccess();
}
};
} // end anonymous namespace.
@@ -1991,21 +1986,16 @@ namespace {
///
struct SimplifyXMinusX : public RewritePattern {
SimplifyXMinusX(MLIRContext *context)
- : RewritePattern(SubIOp::getOperationName(), 1, context) {}
+ : RewritePattern(SubIOp::getOperationName(), 10, context) {}
- PatternMatchResult match(Instruction *op) const override {
- auto subi = op->cast<SubIOp>();
- if (subi->getOperand(0) == subi->getOperand(1))
- return matchSuccess();
-
- return matchFailure();
- }
- void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
+ PatternMatchResult matchAndRewrite(Instruction *op,
+ PatternRewriter &rewriter) const override {
auto subi = op->cast<SubIOp>();
- auto result =
- rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
+ if (subi->getOperand(0) != subi->getOperand(1))
+ return matchFailure();
- rewriter.replaceOp(op, {result});
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi->getType());
+ return matchSuccess();
}
};
} // end anonymous namespace.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2a0238cbd96..a2d6f392c32 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -34,7 +34,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(Function *fn,
OwningRewritePatternList &&patterns)
- : PatternRewriter(fn->getContext()), matcher(std::move(patterns)),
+ : PatternRewriter(fn->getContext()), matcher(std::move(patterns), *this),
builder(fn) {
worklist.reserve(64);
@@ -122,7 +122,7 @@ private:
}
/// The low-level pattern matcher.
- PatternMatcher matcher;
+ RewritePatternMatcher matcher;
/// This builder is used to create new operations.
FuncBuilder builder;
@@ -284,17 +284,13 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
continue;
}
- // Check to see if we have any patterns that match this node.
- auto match = matcher.findMatch(op);
- if (!match.first)
- continue;
-
// Make sure that any new operations are inserted at this point.
builder.setInsertionPoint(op);
- // We know that any pattern that matched is RewritePattern because we
- // initialized the matcher with RewritePatterns.
- auto *rewritePattern = static_cast<RewritePattern *>(match.first);
- rewritePattern->rewrite(op, std::move(match.second), *this);
+
+ // Try to match one of the canonicalization patterns. The rewriter is
+ // automatically notified of any necessary changes, so there is nothing else
+ // to do here.
+ matcher.matchAndRewrite(op);
}
uniquedConstants.clear();
OpenPOWER on IntegriCloud