diff options
| -rw-r--r-- | mlir/include/mlir/IR/Matchers.h | 20 | ||||
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 33 |
2 files changed, 30 insertions, 23 deletions
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index a464612da34..aba63d6e221 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -111,16 +111,24 @@ struct constant_int_op_binder { } }; -// The matcher that matches a given target constant scalar / vector splat / -// tensor splat integer value. +/// The matcher that matches a given target constant scalar / vector splat / +/// tensor splat integer value. template <int64_t TargetValue> struct constant_int_value_matcher { bool match(Operation *op) { APInt value; - return constant_int_op_binder(&value).match(op) && TargetValue == value; } }; +/// The matcher that matches anything except the given target constant scalar / +/// vector splat / tensor splat integer value. +template <int64_t TargetNotValue> struct constant_int_not_value_matcher { + bool match(Operation *op) { + APInt value; + return constant_int_op_binder(&value).match(op) && TargetNotValue != value; + } +}; + /// The matcher that matches a certain kind of op. template <typename OpClass> struct op_matcher { bool match(Operation *op) { return isa<OpClass>(op); } @@ -172,6 +180,12 @@ inline detail::constant_int_value_matcher<0> m_Zero() { return detail::constant_int_value_matcher<0>(); } +/// Matches a constant scalar / vector splat / tensor splat integer that is any +/// non-zero value. +inline detail::constant_int_not_value_matcher<0> m_NonZero() { + return detail::constant_int_not_value_matcher<0>(); +} + } // end namespace mlir #endif // MLIR_MATCHERS_H diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 5a452c5242a..161a6c409c7 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1070,27 +1070,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { - // Check that the condition is a constant. - if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>())) - return matchFailure(); - - Block *foldedDest; - SmallVector<Value *, 4> branchArgs; - - // If the condition is known to evaluate to false we fold to a branch to the - // false destination. Otherwise, we fold to a branch to the true - // destination. - if (matchPattern(condbr.getCondition(), m_Zero())) { - foldedDest = condbr.getFalseDest(); - branchArgs.assign(condbr.false_operand_begin(), - condbr.false_operand_end()); - } else { - foldedDest = condbr.getTrueDest(); - branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); + if (matchPattern(condbr.getCondition(), m_NonZero())) { + // True branch taken. + rewriter.replaceOpWithNewOp<BranchOp>( + condbr, condbr.getTrueDest(), + llvm::to_vector<4>(condbr.getTrueOperands())); + return matchSuccess(); + } else if (matchPattern(condbr.getCondition(), m_Zero())) { + // False branch taken. + rewriter.replaceOpWithNewOp<BranchOp>( + condbr, condbr.getFalseDest(), + llvm::to_vector<4>(condbr.getFalseOperands())); + return matchSuccess(); } - - rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs); - return matchSuccess(); + return matchFailure(); } }; } // end anonymous namespace. |

