summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/Matchers.h12
-rw-r--r--mlir/lib/IR/Builders.cpp3
-rw-r--r--mlir/test/IR/test-matchers.mlir1
-rw-r--r--mlir/test/lib/IR/TestMatchers.cpp3
4 files changed, 16 insertions, 3 deletions
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index d8d3308c7f0..170984b8550 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -56,6 +56,8 @@ template <typename AttrT> struct constant_op_binder {
/// Creates a matcher instance that binds the constant attribute value to
/// bind_value if match succeeds.
constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
+ /// Creates a matcher instance that doesn't bind if match succeeds.
+ constant_op_binder() : bind_value(nullptr) {}
bool match(Operation *op) {
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
@@ -66,8 +68,11 @@ template <typename AttrT> struct constant_op_binder {
SmallVector<OpFoldResult, 1> foldedOp;
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
- if ((*bind_value = attr.dyn_cast<AttrT>()))
+ if (auto attrT = attr.dyn_cast<AttrT>()) {
+ if (bind_value)
+ *bind_value = attrT;
return true;
+ }
}
}
return false;
@@ -196,6 +201,11 @@ struct RecursivePatternMatcher {
} // end namespace detail
+/// Matches a constant foldable operation.
+inline detail::constant_op_binder<Attribute> m_Constant() {
+ return detail::constant_op_binder<Attribute>();
+}
+
/// Matches a value from a constant foldable operation and writes the value to
/// bind_value.
template <typename AttrT>
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 0c72abf9a5e..50663668275 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -342,8 +342,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
};
// If this operation is already a constant, there is nothing to do.
- Attribute unused;
- if (matchPattern(op, m_Constant(&unused)))
+ if (matchPattern(op, m_Constant()))
return cleanupFailure();
// Check to see if any operands to the operation is constant and whether
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
index 7808f25a2f8..60d5bcf7d81 100644
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -40,3 +40,4 @@ func @test2(%a: f32) -> f32 {
// CHECK-LABEL: test2
// CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
+// CHECK: Pattern add(add(a, constant), a) matched
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
index b62daa8437c..6061b251d72 100644
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -126,12 +126,15 @@ void test2(FuncOp f) {
auto a = m_Val(f.getArgument(0));
FloatAttr floatAttr;
auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
+ auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
// Last operation that is not the terminator.
Operation *lastOp = f.getBody().front().back().getPrevNode();
if (p.match(lastOp))
llvm::outs()
<< "Pattern add(add(a, constant), a) matched and bound constant to: "
<< floatAttr.getValueAsDouble() << "\n";
+ if (p1.match(lastOp))
+ llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
}
void TestMatchers::runOnFunction() {
OpenPOWER on IntegriCloud