summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Analysis/NestedMatcher.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/NestedMatcher.cpp')
-rw-r--r--mlir/lib/Analysis/NestedMatcher.cpp26
1 files changed, 18 insertions, 8 deletions
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 46bf5ad0b97..214b4ce403c 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -115,6 +115,10 @@ void NestedPattern::matchOne(Instruction *inst,
}
}
+static bool isAffineForOp(const Instruction &inst) {
+ return cast<OperationInst>(inst).isa<AffineForOp>();
+}
+
static bool isAffineIfOp(const Instruction &inst) {
return isa<OperationInst>(inst) &&
cast<OperationInst>(inst).isa<AffineIfOp>();
@@ -147,28 +151,34 @@ NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
}
NestedPattern For(NestedPattern child) {
- return NestedPattern(Instruction::Kind::For, child, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, NestedPattern child) {
- return NestedPattern(Instruction::Kind::For, child, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, child,
+ [=](const Instruction &inst) {
+ return isAffineForOp(inst) && filter(inst);
+ });
}
NestedPattern For(ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::For, nested, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp);
}
NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::For, nested, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, nested,
+ [=](const Instruction &inst) {
+ return isAffineForOp(inst) && filter(inst);
+ });
}
// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(const Instruction &inst) {
- const auto *loop = cast<ForInst>(&inst);
- return (void *)loop || true; // loop->isParallel();
+ auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
+ return loop || true; // loop->isParallel();
};
// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(const Instruction &inst) {
- const auto *loop = cast<ForInst>(&inst);
- return (void *)loop || true; // loop->isReduction();
+ auto loop = cast<OperationInst>(inst).cast<AffineForOp>();
+ return loop || true; // loop->isReduction();
};
bool isLoadOrStore(const Instruction &inst) {
OpenPOWER on IntegriCloud