diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/RegionUtils.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/RegionUtils.cpp | 348 |
1 files changed, 348 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp new file mode 100644 index 00000000000..ca26074f288 --- /dev/null +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -0,0 +1,348 @@ +//===- RegionUtils.cpp - Region-related transformation utilities ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/RegionUtils.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/RegionGraphTraits.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; + +void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, + Region ®ion) { + for (auto &use : llvm::make_early_inc_range(orig->getUses())) { + if (region.isAncestor(use.getOwner()->getParentRegion())) + use.set(replacement); + } +} + +void mlir::visitUsedValuesDefinedAbove( + Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { + assert(limit.isAncestor(®ion) && + "expected isolation limit to be an ancestor of the given region"); + + // Collect proper ancestors of `limit` upfront to avoid traversing the region + // tree for every value. + SmallPtrSet<Region *, 4> properAncestors; + for (auto *reg = limit.getParentRegion(); reg != nullptr; + reg = reg->getParentRegion()) { + properAncestors.insert(reg); + } + + region.walk([callback, &properAncestors](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) + // Callback on values defined in a proper ancestor of region. + if (properAncestors.count(operand.get()->getParentRegion())) + callback(&operand); + }); +} + +void mlir::visitUsedValuesDefinedAbove( + MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { + for (Region ®ion : regions) + visitUsedValuesDefinedAbove(region, region, callback); +} + +void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, + llvm::SetVector<Value> &values) { + visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { + values.insert(operand->get()); + }); +} + +void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, + llvm::SetVector<Value> &values) { + for (Region ®ion : regions) + getUsedValuesDefinedAbove(region, region, values); +} + +//===----------------------------------------------------------------------===// +// Unreachable Block Elimination +//===----------------------------------------------------------------------===// + +/// Erase the unreachable blocks within the provided regions. Returns success +/// if any blocks were erased, failure otherwise. +// TODO: We could likely merge this with the DCE algorithm below. +static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) { + // Set of blocks found to be reachable within a given region. + llvm::df_iterator_default_set<Block *, 16> reachable; + // If any blocks were found to be dead. + bool erasedDeadBlocks = false; + + SmallVector<Region *, 1> worklist; + worklist.reserve(regions.size()); + for (Region ®ion : regions) + worklist.push_back(®ion); + while (!worklist.empty()) { + Region *region = worklist.pop_back_val(); + if (region->empty()) + continue; + + // If this is a single block region, just collect the nested regions. + if (std::next(region->begin()) == region->end()) { + for (Operation &op : region->front()) + for (Region ®ion : op.getRegions()) + worklist.push_back(®ion); + continue; + } + + // Mark all reachable blocks. + reachable.clear(); + for (Block *block : depth_first_ext(®ion->front(), reachable)) + (void)block /* Mark all reachable blocks */; + + // Collect all of the dead blocks and push the live regions onto the + // worklist. + for (Block &block : llvm::make_early_inc_range(*region)) { + if (!reachable.count(&block)) { + block.dropAllDefinedValueUses(); + block.erase(); + erasedDeadBlocks = true; + continue; + } + + // Walk any regions within this block. + for (Operation &op : block) + for (Region ®ion : op.getRegions()) + worklist.push_back(®ion); + } + } + + return success(erasedDeadBlocks); +} + +//===----------------------------------------------------------------------===// +// Dead Code Elimination +//===----------------------------------------------------------------------===// + +namespace { +/// Data structure used to track which values have already been proved live. +/// +/// Because Operation's can have multiple results, this data structure tracks +/// liveness for both Value's and Operation's to avoid having to look through +/// all Operation results when analyzing a use. +/// +/// This data structure essentially tracks the dataflow lattice. +/// The set of values/ops proved live increases monotonically to a fixed-point. +class LiveMap { +public: + /// Value methods. + bool wasProvenLive(Value value) { return liveValues.count(value); } + void setProvedLive(Value value) { + changed |= liveValues.insert(value).second; + } + + /// Operation methods. + bool wasProvenLive(Operation *op) { return liveOps.count(op); } + void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } + + /// Methods for tracking if we have reached a fixed-point. + void resetChanged() { changed = false; } + bool hasChanged() { return changed; } + +private: + bool changed = false; + DenseSet<Value> liveValues; + DenseSet<Operation *> liveOps; +}; +} // namespace + +static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { + Operation *owner = use.getOwner(); + unsigned operandIndex = use.getOperandNumber(); + // This pass generally treats all uses of an op as live if the op itself is + // considered live. However, for successor operands to terminators we need a + // finer-grained notion where we deduce liveness for operands individually. + // The reason for this is easiest to think about in terms of a classical phi + // node based SSA IR, where each successor operand is really an operand to a + // *separate* phi node, rather than all operands to the branch itself as with + // the block argument representation that MLIR uses. + // + // And similarly, because each successor operand is really an operand to a phi + // node, rather than to the terminator op itself, a terminator op can't e.g. + // "print" the value of a successor operand. + if (owner->isKnownTerminator()) { + if (auto arg = owner->getSuccessorBlockArgument(operandIndex)) + return !liveMap.wasProvenLive(*arg); + return false; + } + return false; +} + +static void processValue(Value value, LiveMap &liveMap) { + bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) { + if (isUseSpeciallyKnownDead(use, liveMap)) + return false; + return liveMap.wasProvenLive(use.getOwner()); + }); + if (provedLive) + liveMap.setProvedLive(value); +} + +static bool isOpIntrinsicallyLive(Operation *op) { + // This pass doesn't modify the CFG, so terminators are never deleted. + if (!op->isKnownNonTerminator()) + return true; + // If the op has a side effect, we treat it as live. + if (!op->hasNoSideEffect()) + return true; + return false; +} + +static void propagateLiveness(Region ®ion, LiveMap &liveMap); +static void propagateLiveness(Operation *op, LiveMap &liveMap) { + // All Value's are either a block argument or an op result. + // We call processValue on those cases. + + // Recurse on any regions the op has. + for (Region ®ion : op->getRegions()) + propagateLiveness(region, liveMap); + + // Process the op itself. + if (isOpIntrinsicallyLive(op)) { + liveMap.setProvedLive(op); + return; + } + for (Value value : op->getResults()) + processValue(value, liveMap); + bool provedLive = llvm::any_of(op->getResults(), [&](Value value) { + return liveMap.wasProvenLive(value); + }); + if (provedLive) + liveMap.setProvedLive(op); +} + +static void propagateLiveness(Region ®ion, LiveMap &liveMap) { + if (region.empty()) + return; + + for (Block *block : llvm::post_order(®ion.front())) { + // We process block arguments after the ops in the block, to promote + // faster convergence to a fixed point (we try to visit uses before defs). + for (Operation &op : llvm::reverse(block->getOperations())) + propagateLiveness(&op, liveMap); + for (Value value : block->getArguments()) + processValue(value, liveMap); + } +} + +static void eraseTerminatorSuccessorOperands(Operation *terminator, + LiveMap &liveMap) { + for (unsigned succI = 0, succE = terminator->getNumSuccessors(); + succI < succE; succI++) { + // Iterating successors in reverse is not strictly needed, since we + // aren't erasing any successors. But it is slightly more efficient + // since it will promote later operands of the terminator being erased + // first, reducing the quadratic-ness. + unsigned succ = succE - succI - 1; + for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ); + argI < argE; argI++) { + // Iterating args in reverse is needed for correctness, to avoid + // shifting later args when earlier args are erased. + unsigned arg = argE - argI - 1; + Value value = terminator->getSuccessor(succ)->getArgument(arg); + if (!liveMap.wasProvenLive(value)) { + terminator->eraseSuccessorOperand(succ, arg); + } + } + } +} + +static LogicalResult deleteDeadness(MutableArrayRef<Region> regions, + LiveMap &liveMap) { + bool erasedAnything = false; + for (Region ®ion : regions) { + if (region.empty()) + continue; + + // We do the deletion in an order that deletes all uses before deleting + // defs. + // MLIR's SSA structural invariants guarantee that except for block + // arguments, the use-def graph is acyclic, so this is possible with a + // single walk of ops and then a final pass to clean up block arguments. + // + // To do this, we visit ops in an order that visits domtree children + // before domtree parents. A CFG post-order (with reverse iteration with a + // block) satisfies that without needing an explicit domtree calculation. + for (Block *block : llvm::post_order(®ion.front())) { + eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); + for (Operation &childOp : + llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { + erasedAnything |= + succeeded(deleteDeadness(childOp.getRegions(), liveMap)); + if (!liveMap.wasProvenLive(&childOp)) { + erasedAnything = true; + childOp.erase(); + } + } + } + // Delete block arguments. + // The entry block has an unknown contract with their enclosing block, so + // skip it. + for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { + // Iterate in reverse to avoid shifting later arguments when deleting + // earlier arguments. + for (unsigned i = 0, e = block.getNumArguments(); i < e; i++) + if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) { + block.eraseArgument(e - i - 1, /*updatePredTerms=*/false); + erasedAnything = true; + } + } + } + return success(erasedAnything); +} + +// This function performs a simple dead code elimination algorithm over the +// given regions. +// +// The overall goal is to prove that Values are dead, which allows deleting ops +// and block arguments. +// +// This uses an optimistic algorithm that assumes everything is dead until +// proved otherwise, allowing it to delete recursively dead cycles. +// +// This is a simple fixed-point dataflow analysis algorithm on a lattice +// {Dead,Alive}. Because liveness flows backward, we generally try to +// iterate everything backward to speed up convergence to the fixed-point. This +// allows for being able to delete recursively dead cycles of the use-def graph, +// including block arguments. +// +// This function returns success if any operations or arguments were deleted, +// failure otherwise. +static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) { + assert(regions.size() == 1); + + LiveMap liveMap; + do { + liveMap.resetChanged(); + + for (Region ®ion : regions) + propagateLiveness(region, liveMap); + } while (liveMap.hasChanged()); + + return deleteDeadness(regions, liveMap); +} + +//===----------------------------------------------------------------------===// +// Region Simplification +//===----------------------------------------------------------------------===// + +/// Run a set of structural simplifications over the given regions. This +/// includes transformations like unreachable block elimination, dead argument +/// elimination, as well as some other DCE. This function returns success if any +/// of the regions were simplified, failure otherwise. +LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) { + LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); + LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); + return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); +} |