summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/Matchers.h20
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp33
2 files changed, 30 insertions, 23 deletions
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index a464612da34..aba63d6e221 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -111,16 +111,24 @@ struct constant_int_op_binder {
}
};
-// The matcher that matches a given target constant scalar / vector splat /
-// tensor splat integer value.
+/// The matcher that matches a given target constant scalar / vector splat /
+/// tensor splat integer value.
template <int64_t TargetValue> struct constant_int_value_matcher {
bool match(Operation *op) {
APInt value;
-
return constant_int_op_binder(&value).match(op) && TargetValue == value;
}
};
+/// The matcher that matches anything except the given target constant scalar /
+/// vector splat / tensor splat integer value.
+template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
+ bool match(Operation *op) {
+ APInt value;
+ return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
+ }
+};
+
/// The matcher that matches a certain kind of op.
template <typename OpClass> struct op_matcher {
bool match(Operation *op) { return isa<OpClass>(op); }
@@ -172,6 +180,12 @@ inline detail::constant_int_value_matcher<0> m_Zero() {
return detail::constant_int_value_matcher<0>();
}
+/// Matches a constant scalar / vector splat / tensor splat integer that is any
+/// non-zero value.
+inline detail::constant_int_not_value_matcher<0> m_NonZero() {
+ return detail::constant_int_not_value_matcher<0>();
+}
+
} // end namespace mlir
#endif // MLIR_MATCHERS_H
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 5a452c5242a..161a6c409c7 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1070,27 +1070,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
- // Check that the condition is a constant.
- if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
- return matchFailure();
-
- Block *foldedDest;
- SmallVector<Value *, 4> branchArgs;
-
- // If the condition is known to evaluate to false we fold to a branch to the
- // false destination. Otherwise, we fold to a branch to the true
- // destination.
- if (matchPattern(condbr.getCondition(), m_Zero())) {
- foldedDest = condbr.getFalseDest();
- branchArgs.assign(condbr.false_operand_begin(),
- condbr.false_operand_end());
- } else {
- foldedDest = condbr.getTrueDest();
- branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
+ if (matchPattern(condbr.getCondition(), m_NonZero())) {
+ // True branch taken.
+ rewriter.replaceOpWithNewOp<BranchOp>(
+ condbr, condbr.getTrueDest(),
+ llvm::to_vector<4>(condbr.getTrueOperands()));
+ return matchSuccess();
+ } else if (matchPattern(condbr.getCondition(), m_Zero())) {
+ // False branch taken.
+ rewriter.replaceOpWithNewOp<BranchOp>(
+ condbr, condbr.getFalseDest(),
+ llvm::to_vector<4>(condbr.getFalseOperands()));
+ return matchSuccess();
}
-
- rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
- return matchSuccess();
+ return matchFailure();
}
};
} // end anonymous namespace.
OpenPOWER on IntegriCloud