diff options
Diffstat (limited to 'mlir/lib/Analysis/NestedMatcher.cpp')
| -rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 4f32e9b22f4..491a9bef1b9 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } +static bool isAffineIfOp(const Instruction &inst) { + return isa<OperationInst>(inst) && + cast<OperationInst>(inst).isa<AffineIfOp>(); +} + namespace mlir { namespace matcher { @@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern If(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::If, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern For(NestedPattern child) { |

