diff options
| author | River Riddle <riverriddle@google.com> | 2019-08-30 12:47:24 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-30 12:47:53 -0700 |
| commit | 037742cdf2b9794941958a985e8a0d2023aaa70d (patch) | |
| tree | 105ffd387ea19e6c2d6226de2e07ba018423289f /mlir/lib/Transforms/Utils | |
| parent | 4f6c29223ee5395dd955cefafce6f03ed99170e0 (diff) | |
| download | bcm5719-llvm-037742cdf2b9794941958a985e8a0d2023aaa70d.tar.gz bcm5719-llvm-037742cdf2b9794941958a985e8a0d2023aaa70d.zip | |
Add support for early exit walk methods.
This is done by providing a walk callback that returns a WalkResult. This result is either `advance` or `interrupt`. `advance` means that the walk should continue, whereas `interrupt` signals that the walk should stop immediately. An example is shown below:
auto result = op->walk([](Operation *op) {
if (some_invariant)
return WalkResult::interrupt();
return WalkResult::advance();
});
if (result.wasInterrupted())
...;
PiperOrigin-RevId: 266436700
Diffstat (limited to 'mlir/lib/Transforms/Utils')
| -rw-r--r-- | mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 99f315e3fd0..8f96cc23fb9 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -114,12 +114,12 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { it != Block::reverse_iterator(opA); ++it) { Operation *opX = &(*it); opX->walk([&](Operation *op) { - if (lastDepOp) - return; if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) { - if (isDependentLoadOrStoreOp(op, values)) + if (isDependentLoadOrStoreOp(op, values)) { lastDepOp = opX; - return; + return WalkResult::interrupt(); + } + return WalkResult::advance(); } for (auto *value : op->getResults()) { for (auto *user : value->getUsers()) { @@ -128,9 +128,11 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { getLoopIVs(*user, &loops); if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { lastDepOp = opX; + return WalkResult::interrupt(); } } } + return WalkResult::advance(); }); if (lastDepOp) break; @@ -257,15 +259,13 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, /// returns false otherwise. bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { - bool ret = true; - forOpRoot.walk([&](AffineForOp forOp) { + auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { auto *childForOp = forOp.getOperation(); auto *parentForOp = forOp.getOperation()->getParentOp(); if (!llvm::isa<FuncOp>(parentForOp)) { if (!isa<AffineForOp>(parentForOp)) { LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp"); - ret = false; - return; + return WalkResult::interrupt(); } // Add mapping to 'forOp' from its parent AffineForOp. stats->loopMap[parentForOp].push_back(forOp); @@ -279,18 +279,20 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { ++count; } stats->opCountMap[childForOp] = count; + // Record trip count for 'forOp'. Set flag if trip count is not // constant. Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount.hasValue()) { // Currently only constant trip count loop nests are supported. LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported"); - ret = false; - return; + return WalkResult::interrupt(); } + stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); + return WalkResult::advance(); }); - return ret; + return !walkResult.wasInterrupted(); } // Computes the total cost of the loop nest rooted at 'forOp'. |

