summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Analysis/NestedMatcher.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/NestedMatcher.cpp')
-rw-r--r--mlir/lib/Analysis/NestedMatcher.cpp20
1 files changed, 16 insertions, 4 deletions
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 4f32e9b22f4..491a9bef1b9 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/ArrayRef.h"
@@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() {
return storage->filter;
}
+static bool isAffineIfOp(const Instruction &inst) {
+ return isa<OperationInst>(inst) &&
+ cast<OperationInst>(inst).isa<AffineIfOp>();
+}
+
namespace mlir {
namespace matcher {
@@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) {
}
NestedPattern If(NestedPattern child) {
- return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
- return NestedPattern(Instruction::Kind::If, child, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, child,
+ [filter](const Instruction &inst) {
+ return isAffineIfOp(inst) && filter(inst);
+ });
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::If, nested, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, nested,
+ [filter](const Instruction &inst) {
+ return isAffineIfOp(inst) && filter(inst);
+ });
}
NestedPattern For(NestedPattern child) {
OpenPOWER on IntegriCloud