//===- RegionUtils.cpp - Region-related transformation utilities ----------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Transforms/RegionUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/Operation.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/IR/Value.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallSet.h" using namespace mlir; void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion) { for (auto &use : llvm::make_early_inc_range(orig.getUses())) { if (region.isAncestor(use.getOwner()->getParentRegion())) use.set(replacement); } } void mlir::visitUsedValuesDefinedAbove( Region ®ion, Region &limit, function_ref callback) { assert(limit.isAncestor(®ion) && "expected isolation limit to be an ancestor of the given region"); // Collect proper ancestors of `limit` upfront to avoid traversing the region // tree for every value. SmallPtrSet properAncestors; for (auto *reg = limit.getParentRegion(); reg != nullptr; reg = reg->getParentRegion()) { properAncestors.insert(reg); } region.walk([callback, &properAncestors](Operation *op) { for (OpOperand &operand : op->getOpOperands()) // Callback on values defined in a proper ancestor of region. if (properAncestors.count(operand.get().getParentRegion())) callback(&operand); }); } void mlir::visitUsedValuesDefinedAbove( MutableArrayRef regions, function_ref callback) { for (Region ®ion : regions) visitUsedValuesDefinedAbove(region, region, callback); } void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, llvm::SetVector &values) { visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { values.insert(operand->get()); }); } void mlir::getUsedValuesDefinedAbove(MutableArrayRef regions, llvm::SetVector &values) { for (Region ®ion : regions) getUsedValuesDefinedAbove(region, region, values); } //===----------------------------------------------------------------------===// // Unreachable Block Elimination //===----------------------------------------------------------------------===// /// Erase the unreachable blocks within the provided regions. Returns success /// if any blocks were erased, failure otherwise. // TODO: We could likely merge this with the DCE algorithm below. static LogicalResult eraseUnreachableBlocks(MutableArrayRef regions) { // Set of blocks found to be reachable within a given region. llvm::df_iterator_default_set reachable; // If any blocks were found to be dead. bool erasedDeadBlocks = false; SmallVector worklist; worklist.reserve(regions.size()); for (Region ®ion : regions) worklist.push_back(®ion); while (!worklist.empty()) { Region *region = worklist.pop_back_val(); if (region->empty()) continue; // If this is a single block region, just collect the nested regions. if (std::next(region->begin()) == region->end()) { for (Operation &op : region->front()) for (Region ®ion : op.getRegions()) worklist.push_back(®ion); continue; } // Mark all reachable blocks. reachable.clear(); for (Block *block : depth_first_ext(®ion->front(), reachable)) (void)block /* Mark all reachable blocks */; // Collect all of the dead blocks and push the live regions onto the // worklist. for (Block &block : llvm::make_early_inc_range(*region)) { if (!reachable.count(&block)) { block.dropAllDefinedValueUses(); block.erase(); erasedDeadBlocks = true; continue; } // Walk any regions within this block. for (Operation &op : block) for (Region ®ion : op.getRegions()) worklist.push_back(®ion); } } return success(erasedDeadBlocks); } //===----------------------------------------------------------------------===// // Dead Code Elimination //===----------------------------------------------------------------------===// namespace { /// Data structure used to track which values have already been proved live. /// /// Because Operation's can have multiple results, this data structure tracks /// liveness for both Value's and Operation's to avoid having to look through /// all Operation results when analyzing a use. /// /// This data structure essentially tracks the dataflow lattice. /// The set of values/ops proved live increases monotonically to a fixed-point. class LiveMap { public: /// Value methods. bool wasProvenLive(Value value) { return liveValues.count(value); } void setProvedLive(Value value) { changed |= liveValues.insert(value).second; } /// Operation methods. bool wasProvenLive(Operation *op) { return liveOps.count(op); } void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } /// Methods for tracking if we have reached a fixed-point. void resetChanged() { changed = false; } bool hasChanged() { return changed; } private: bool changed = false; DenseSet liveValues; DenseSet liveOps; }; } // namespace static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { Operation *owner = use.getOwner(); unsigned operandIndex = use.getOperandNumber(); // This pass generally treats all uses of an op as live if the op itself is // considered live. However, for successor operands to terminators we need a // finer-grained notion where we deduce liveness for operands individually. // The reason for this is easiest to think about in terms of a classical phi // node based SSA IR, where each successor operand is really an operand to a // *separate* phi node, rather than all operands to the branch itself as with // the block argument representation that MLIR uses. // // And similarly, because each successor operand is really an operand to a phi // node, rather than to the terminator op itself, a terminator op can't e.g. // "print" the value of a successor operand. if (owner->isKnownTerminator()) { if (auto arg = owner->getSuccessorBlockArgument(operandIndex)) return !liveMap.wasProvenLive(*arg); return false; } return false; } static void processValue(Value value, LiveMap &liveMap) { bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) { if (isUseSpeciallyKnownDead(use, liveMap)) return false; return liveMap.wasProvenLive(use.getOwner()); }); if (provedLive) liveMap.setProvedLive(value); } static bool isOpIntrinsicallyLive(Operation *op) { // This pass doesn't modify the CFG, so terminators are never deleted. if (!op->isKnownNonTerminator()) return true; // If the op has a side effect, we treat it as live. if (!op->hasNoSideEffect()) return true; return false; } static void propagateLiveness(Region ®ion, LiveMap &liveMap); static void propagateLiveness(Operation *op, LiveMap &liveMap) { // All Value's are either a block argument or an op result. // We call processValue on those cases. // Recurse on any regions the op has. for (Region ®ion : op->getRegions()) propagateLiveness(region, liveMap); // Process the op itself. if (isOpIntrinsicallyLive(op)) { liveMap.setProvedLive(op); return; } for (Value value : op->getResults()) processValue(value, liveMap); bool provedLive = llvm::any_of(op->getResults(), [&](Value value) { return liveMap.wasProvenLive(value); }); if (provedLive) liveMap.setProvedLive(op); } static void propagateLiveness(Region ®ion, LiveMap &liveMap) { if (region.empty()) return; for (Block *block : llvm::post_order(®ion.front())) { // We process block arguments after the ops in the block, to promote // faster convergence to a fixed point (we try to visit uses before defs). for (Operation &op : llvm::reverse(block->getOperations())) propagateLiveness(&op, liveMap); for (Value value : block->getArguments()) processValue(value, liveMap); } } static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap) { for (unsigned succI = 0, succE = terminator->getNumSuccessors(); succI < succE; succI++) { // Iterating successors in reverse is not strictly needed, since we // aren't erasing any successors. But it is slightly more efficient // since it will promote later operands of the terminator being erased // first, reducing the quadratic-ness. unsigned succ = succE - succI - 1; for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ); argI < argE; argI++) { // Iterating args in reverse is needed for correctness, to avoid // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; Value value = terminator->getSuccessor(succ)->getArgument(arg); if (!liveMap.wasProvenLive(value)) { terminator->eraseSuccessorOperand(succ, arg); } } } } static LogicalResult deleteDeadness(MutableArrayRef regions, LiveMap &liveMap) { bool erasedAnything = false; for (Region ®ion : regions) { if (region.empty()) continue; // We do the deletion in an order that deletes all uses before deleting // defs. // MLIR's SSA structural invariants guarantee that except for block // arguments, the use-def graph is acyclic, so this is possible with a // single walk of ops and then a final pass to clean up block arguments. // // To do this, we visit ops in an order that visits domtree children // before domtree parents. A CFG post-order (with reverse iteration with a // block) satisfies that without needing an explicit domtree calculation. for (Block *block : llvm::post_order(®ion.front())) { eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); for (Operation &childOp : llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { erasedAnything |= succeeded(deleteDeadness(childOp.getRegions(), liveMap)); if (!liveMap.wasProvenLive(&childOp)) { erasedAnything = true; childOp.erase(); } } } // Delete block arguments. // The entry block has an unknown contract with their enclosing block, so // skip it. for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { // Iterate in reverse to avoid shifting later arguments when deleting // earlier arguments. for (unsigned i = 0, e = block.getNumArguments(); i < e; i++) if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) { block.eraseArgument(e - i - 1, /*updatePredTerms=*/false); erasedAnything = true; } } } return success(erasedAnything); } // This function performs a simple dead code elimination algorithm over the // given regions. // // The overall goal is to prove that Values are dead, which allows deleting ops // and block arguments. // // This uses an optimistic algorithm that assumes everything is dead until // proved otherwise, allowing it to delete recursively dead cycles. // // This is a simple fixed-point dataflow analysis algorithm on a lattice // {Dead,Alive}. Because liveness flows backward, we generally try to // iterate everything backward to speed up convergence to the fixed-point. This // allows for being able to delete recursively dead cycles of the use-def graph, // including block arguments. // // This function returns success if any operations or arguments were deleted, // failure otherwise. static LogicalResult runRegionDCE(MutableArrayRef regions) { LiveMap liveMap; do { liveMap.resetChanged(); for (Region ®ion : regions) propagateLiveness(region, liveMap); } while (liveMap.hasChanged()); return deleteDeadness(regions, liveMap); } //===----------------------------------------------------------------------===// // Region Simplification //===----------------------------------------------------------------------===// /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any /// of the regions were simplified, failure otherwise. LogicalResult mlir::simplifyRegions(MutableArrayRef regions) { LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); }