summaryrefslogtreecommitdiffstats
path: root/mlir/test/lib
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-12-09 07:47:01 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-09 07:47:35 -0800
commit7b19bd5411a68399db4bcf3c2804a67f1d0b3a62 (patch)
treead244715a81f118d2e8bc16905f8e92878e1ccef /mlir/test/lib
parenta63f6e0bf98f63e5c18acbaf9eacd8fde6a1b001 (diff)
downloadbcm5719-llvm-7b19bd5411a68399db4bcf3c2804a67f1d0b3a62.tar.gz
bcm5719-llvm-7b19bd5411a68399db4bcf3c2804a67f1d0b3a62.zip
Post-submit cleanups in RecursiveMatchers
This CL addresses leftover cleanups and adds a test mixing RecursiveMatchers and m_Constant that captures properly. PiperOrigin-RevId: 284551567
Diffstat (limited to 'mlir/test/lib')
-rw-r--r--mlir/test/lib/IR/TestMatchers.cpp62
1 files changed, 34 insertions, 28 deletions
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
index c0b92a8c433..5985a88ffa6 100644
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -24,8 +24,8 @@ using namespace mlir;
namespace {
/// This is a test pass for verifying matchers.
-struct TestMatchers : public ModulePass<TestMatchers> {
- void runOnModule() override;
+struct TestMatchers : public FunctionPass<TestMatchers> {
+ void runOnFunction() override;
};
} // end anonymous namespace
@@ -33,27 +33,20 @@ struct TestMatchers : public ModulePass<TestMatchers> {
template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) {
unsigned count = 0;
f.walk([&count, &matcher](Operation *op) {
- if (matcher.match(op)) {
- // llvm::outs() << "matched " << *op << "\n";
+ if (matcher.match(op))
++count;
- }
});
return count;
}
+using mlir::matchers::m_Any;
+using mlir::matchers::m_Val;
static void test1(FuncOp f) {
- using mlir::matchers::m_any;
- using mlir::matchers::m_val;
-
assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
- auto a = m_val(f.getArgument(0));
- auto b = m_val(f.getArgument(1));
- auto c = m_val(f.getArgument(2));
- (void)a;
- (void)b;
- (void)c;
- llvm::outs() << f.getName();
+ auto a = m_Val(f.getArgument(0));
+ auto b = m_Val(f.getArgument(1));
+ auto c = m_Val(f.getArgument(2));
auto p0 = m_Op<AddFOp>(); // using 0-arity matcher
llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
@@ -63,23 +56,23 @@ static void test1(FuncOp f) {
llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
<< " times\n";
- auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_any());
+ auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any());
llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
<< " times\n";
- auto p3 = m_Op<AddFOp>(m_any(), m_Op<AddFOp>());
+ auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>());
llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
<< " times\n";
- auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_any());
+ auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any());
llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
<< " times\n";
- auto p5 = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
+ auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
<< " times\n";
- auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_any());
+ auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any());
llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
<< " times\n";
@@ -95,7 +88,7 @@ static void test1(FuncOp f) {
// clang-format off
auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
- auto mul_of_anyadd = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
+ auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
mul_of_muladd, m_Op<MulFOp>()),
m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
@@ -128,7 +121,7 @@ static void test1(FuncOp f) {
llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
<< " times\n";
- auto mul_of_aany = m_Op<MulFOp>(a, m_any());
+ auto mul_of_aany = m_Op<MulFOp>(a, m_Any());
auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c));
llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
<< countMatches(f, p16) << " times\n";
@@ -138,12 +131,25 @@ static void test1(FuncOp f) {
<< countMatches(f, p17) << " times\n";
}
-void TestMatchers::runOnModule() {
- auto m = getModule();
- for (auto f : m.getOps<FuncOp>()) {
- if (f.getName() == "test1")
- test1(f);
- }
+void test2(FuncOp f) {
+ auto a = m_Val(f.getArgument(0));
+ FloatAttr floatAttr;
+ auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
+ // Last operation that is not the terminator.
+ Operation *lastOp = f.getBody().front().back().getPrevNode();
+ if (p.match(lastOp))
+ llvm::outs()
+ << "Pattern add(add(a, constant), a) matched and bound constant to: "
+ << floatAttr.getValueAsDouble() << "\n";
+}
+
+void TestMatchers::runOnFunction() {
+ auto f = getFunction();
+ llvm::outs() << f.getName() << "\n";
+ if (f.getName() == "test1")
+ test1(f);
+ if (f.getName() == "test2")
+ test2(f);
}
static PassRegistration<TestMatchers> pass("test-matchers",
OpenPOWER on IntegriCloud