diff options
| author | Nicolas Vasilache <ntv@google.com> | 2019-12-09 07:47:01 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-09 07:47:35 -0800 |
| commit | 7b19bd5411a68399db4bcf3c2804a67f1d0b3a62 (patch) | |
| tree | ad244715a81f118d2e8bc16905f8e92878e1ccef /mlir/test/lib | |
| parent | a63f6e0bf98f63e5c18acbaf9eacd8fde6a1b001 (diff) | |
| download | bcm5719-llvm-7b19bd5411a68399db4bcf3c2804a67f1d0b3a62.tar.gz bcm5719-llvm-7b19bd5411a68399db4bcf3c2804a67f1d0b3a62.zip | |
Post-submit cleanups in RecursiveMatchers
This CL addresses leftover cleanups and adds a test mixing RecursiveMatchers and m_Constant
that captures properly.
PiperOrigin-RevId: 284551567
Diffstat (limited to 'mlir/test/lib')
| -rw-r--r-- | mlir/test/lib/IR/TestMatchers.cpp | 62 |
1 files changed, 34 insertions, 28 deletions
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp index c0b92a8c433..5985a88ffa6 100644 --- a/mlir/test/lib/IR/TestMatchers.cpp +++ b/mlir/test/lib/IR/TestMatchers.cpp @@ -24,8 +24,8 @@ using namespace mlir; namespace { /// This is a test pass for verifying matchers. -struct TestMatchers : public ModulePass<TestMatchers> { - void runOnModule() override; +struct TestMatchers : public FunctionPass<TestMatchers> { + void runOnFunction() override; }; } // end anonymous namespace @@ -33,27 +33,20 @@ struct TestMatchers : public ModulePass<TestMatchers> { template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) { unsigned count = 0; f.walk([&count, &matcher](Operation *op) { - if (matcher.match(op)) { - // llvm::outs() << "matched " << *op << "\n"; + if (matcher.match(op)) ++count; - } }); return count; } +using mlir::matchers::m_Any; +using mlir::matchers::m_Val; static void test1(FuncOp f) { - using mlir::matchers::m_any; - using mlir::matchers::m_val; - assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args"); - auto a = m_val(f.getArgument(0)); - auto b = m_val(f.getArgument(1)); - auto c = m_val(f.getArgument(2)); - (void)a; - (void)b; - (void)c; - llvm::outs() << f.getName(); + auto a = m_Val(f.getArgument(0)); + auto b = m_Val(f.getArgument(1)); + auto c = m_Val(f.getArgument(2)); auto p0 = m_Op<AddFOp>(); // using 0-arity matcher llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0) @@ -63,23 +56,23 @@ static void test1(FuncOp f) { llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1) << " times\n"; - auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_any()); + auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any()); llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2) << " times\n"; - auto p3 = m_Op<AddFOp>(m_any(), m_Op<AddFOp>()); + auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>()); llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3) << " times\n"; - auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_any()); + auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any()); llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4) << " times\n"; - auto p5 = m_Op<MulFOp>(m_any(), m_Op<AddFOp>()); + auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5) << " times\n"; - auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_any()); + auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any()); llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6) << " times\n"; @@ -95,7 +88,7 @@ static void test1(FuncOp f) { // clang-format off auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>()); - auto mul_of_anyadd = m_Op<MulFOp>(m_any(), m_Op<AddFOp>()); + auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); auto p9 = m_Op<MulFOp>(m_Op<MulFOp>( mul_of_muladd, m_Op<MulFOp>()), m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd)); @@ -128,7 +121,7 @@ static void test1(FuncOp f) { llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15) << " times\n"; - auto mul_of_aany = m_Op<MulFOp>(a, m_any()); + auto mul_of_aany = m_Op<MulFOp>(a, m_Any()); auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c)); llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched " << countMatches(f, p16) << " times\n"; @@ -138,12 +131,25 @@ static void test1(FuncOp f) { << countMatches(f, p17) << " times\n"; } -void TestMatchers::runOnModule() { - auto m = getModule(); - for (auto f : m.getOps<FuncOp>()) { - if (f.getName() == "test1") - test1(f); - } +void test2(FuncOp f) { + auto a = m_Val(f.getArgument(0)); + FloatAttr floatAttr; + auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr))); + // Last operation that is not the terminator. + Operation *lastOp = f.getBody().front().back().getPrevNode(); + if (p.match(lastOp)) + llvm::outs() + << "Pattern add(add(a, constant), a) matched and bound constant to: " + << floatAttr.getValueAsDouble() << "\n"; +} + +void TestMatchers::runOnFunction() { + auto f = getFunction(); + llvm::outs() << f.getName() << "\n"; + if (f.getName() == "test1") + test1(f); + if (f.getName() == "test2") + test2(f); } static PassRegistration<TestMatchers> pass("test-matchers", |

