diff options
Diffstat (limited to 'mlir/lib/Analysis/NestedMatcher.cpp')
| -rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 46bf5ad0b97..214b4ce403c 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst, } } +static bool isAffineForOp(const Instruction &inst) { + return cast<OperationInst>(inst).isa<AffineForOp>(); +} + static bool isAffineIfOp(const Instruction &inst) { return isa<OperationInst>(inst) && cast<OperationInst>(inst).isa<AffineIfOp>(); @@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { } NestedPattern For(NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::For, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } NestedPattern For(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::For, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } // TODO(ntv): parallel annotation on loops. bool isParallelLoop(const Instruction &inst) { - const auto *loop = cast<ForInst>(&inst); - return (void *)loop || true; // loop->isParallel(); + auto loop = cast<OperationInst>(inst).cast<AffineForOp>(); + return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. bool isReductionLoop(const Instruction &inst) { - const auto *loop = cast<ForInst>(&inst); - return (void *)loop || true; // loop->isReduction(); + auto loop = cast<OperationInst>(inst).cast<AffineForOp>(); + return loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { |

