diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/LoopFusionUtils.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 480 |
1 files changed, 480 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp new file mode 100644 index 00000000000..b0d9fdf5fd8 --- /dev/null +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -0,0 +1,480 @@ +//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop fusion transformation utility functions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/LoopFusionUtils.h" + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "loop-fusion-utils" + +using namespace mlir; + +// Gathers all load and store memref accesses in 'opA' into 'values', where +// 'values[memref] == true' for each store operation. +static void getLoadAndStoreMemRefAccesses(Operation *opA, + DenseMap<Value, bool> &values) { + opA->walk([&](Operation *op) { + if (auto loadOp = dyn_cast<AffineLoadOp>(op)) { + if (values.count(loadOp.getMemRef()) == 0) + values[loadOp.getMemRef()] = false; + } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { + values[storeOp.getMemRef()] = true; + } + }); +} + +// Returns true if 'op' is a load or store operation which access an memref +// accessed 'values' and at least one of the access is a store operation. +// Returns false otherwise. +static bool isDependentLoadOrStoreOp(Operation *op, + DenseMap<Value, bool> &values) { + if (auto loadOp = dyn_cast<AffineLoadOp>(op)) { + return values.count(loadOp.getMemRef()) > 0 && + values[loadOp.getMemRef()] == true; + } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { + return values.count(storeOp.getMemRef()) > 0; + } + return false; +} + +// Returns the first operation in range ('opA', 'opB') which has a data +// dependence on 'opA'. Returns 'nullptr' of no dependence exists. +static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { + // Record memref values from all loads/store in loop nest rooted at 'opA'. + // Map from memref value to bool which is true if store, false otherwise. + DenseMap<Value, bool> values; + getLoadAndStoreMemRefAccesses(opA, values); + + // For each 'opX' in block in range ('opA', 'opB'), check if there is a data + // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref + // and at least one of the accesses is a store). + Operation *firstDepOp = nullptr; + for (Block::iterator it = std::next(Block::iterator(opA)); + it != Block::iterator(opB); ++it) { + Operation *opX = &(*it); + opX->walk([&](Operation *op) { + if (!firstDepOp && isDependentLoadOrStoreOp(op, values)) + firstDepOp = opX; + }); + if (firstDepOp) + break; + } + return firstDepOp; +} + +// Returns the last operation 'opX' in range ('opA', 'opB'), for which there +// exists a data dependence from 'opX' to 'opB'. +// Returns 'nullptr' of no dependence exists. +static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { + // Record memref values from all loads/store in loop nest rooted at 'opB'. + // Map from memref value to bool which is true if store, false otherwise. + DenseMap<Value, bool> values; + getLoadAndStoreMemRefAccesses(opB, values); + + // For each 'opX' in block in range ('opA', 'opB') in reverse order, + // check if there is a data dependence from 'opX' to 'opB': + // *) 'opX' and 'opB' access the same memref and at least one of the accesses + // is a store. + // *) 'opX' produces an SSA Value which is used by 'opB'. + Operation *lastDepOp = nullptr; + for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); + it != Block::reverse_iterator(opA); ++it) { + Operation *opX = &(*it); + opX->walk([&](Operation *op) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) { + if (isDependentLoadOrStoreOp(op, values)) { + lastDepOp = opX; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + for (auto value : op->getResults()) { + for (auto user : value->getUsers()) { + SmallVector<AffineForOp, 4> loops; + // Check if any loop in loop nest surrounding 'user' is 'opB'. + getLoopIVs(*user, &loops); + if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { + lastDepOp = opX; + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + if (lastDepOp) + break; + } + return lastDepOp; +} + +// Computes and returns an insertion point operation, before which the +// the fused <srcForOp, dstForOp> loop nest can be inserted while preserving +// dependences. Returns nullptr if no such insertion point is found. +static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, + AffineForOp dstForOp) { + bool isSrcForOpBeforeDstForOp = + srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); + auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; + auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; + + auto *firstDepOpA = + getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + auto *lastDepOpB = + getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + // Block: + // ... + // |-- opA + // | ... + // | lastDepOpB --| + // | ... | + // |-> firstDepOpA | + // ... | + // opB <--------- + // + // Valid insertion point range: (lastDepOpB, firstDepOpA) + // + if (firstDepOpA != nullptr) { + if (lastDepOpB != nullptr) { + if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) + // No valid insertion point exists which preserves dependences. + return nullptr; + } + // Return insertion point in valid range closest to 'opB'. + // TODO(andydavis) Consider other insertion points in valid range. + return firstDepOpA; + } + // No dependences from 'opA' to operation in range ('opA', 'opB'), return + // 'opB' insertion point. + return forOpB.getOperation(); +} + +// Gathers all load and store ops in loop nest rooted at 'forOp' into +// 'loadAndStoreOps'. +static bool +gatherLoadsAndStores(AffineForOp forOp, + SmallVectorImpl<Operation *> &loadAndStoreOps) { + bool hasIfOp = false; + forOp.walk([&](Operation *op) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) + loadAndStoreOps.push_back(op); + else if (isa<AffineIfOp>(op)) + hasIfOp = true; + }); + return !hasIfOp; +} + +// TODO(andydavis) Prevent fusion of loop nests with side-effecting operations. +FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, + unsigned dstLoopDepth, + ComputationSliceState *srcSlice) { + // Return 'failure' if 'dstLoopDepth == 0'. + if (dstLoopDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n."); + return FusionResult::FailPrecondition; + } + // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. + auto *block = srcForOp.getOperation()->getBlock(); + if (block != dstForOp.getOperation()->getBlock()) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); + return FusionResult::FailPrecondition; + } + + // Return 'failure' if no valid insertion point for fused loop nest in 'block' + // exists which would preserve dependences. + if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); + return FusionResult::FailBlockDependence; + } + + // Check if 'srcForOp' precedes 'dstForOp' in 'block'. + bool isSrcForOpBeforeDstForOp = + srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); + // 'forOpA' executes before 'forOpB' in 'block'. + auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; + auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; + + // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. + SmallVector<Operation *, 4> opsA; + if (!gatherLoadsAndStores(forOpA, opsA)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. + SmallVector<Operation *, 4> opsB; + if (!gatherLoadsAndStores(forOpB, opsB)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. + unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( + *srcForOp.getOperation(), *dstForOp.getOperation()); + + // Compute union of computation slices computed between all pairs of ops + // from 'forOpA' and 'forOpB'. + if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops, + isSrcForOpBeforeDstForOp, srcSlice))) { + LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); + return FusionResult::FailPrecondition; + } + + return FusionResult::Success; +} + +/// Collect loop nest statistics (eg. loop trip count and operation count) +/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, +/// returns false otherwise. +bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { + auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { + auto *childForOp = forOp.getOperation(); + auto *parentForOp = forOp.getParentOp(); + if (!llvm::isa<FuncOp>(parentForOp)) { + if (!isa<AffineForOp>(parentForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp"); + return WalkResult::interrupt(); + } + // Add mapping to 'forOp' from its parent AffineForOp. + stats->loopMap[parentForOp].push_back(forOp); + } + + // Record the number of op operations in the body of 'forOp'. + unsigned count = 0; + stats->opCountMap[childForOp] = 0; + for (auto &op : *forOp.getBody()) { + if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op)) + ++count; + } + stats->opCountMap[childForOp] = count; + + // Record trip count for 'forOp'. Set flag if trip count is not + // constant. + Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); + if (!maybeConstTripCount.hasValue()) { + // Currently only constant trip count loop nests are supported. + LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported"); + return WalkResult::interrupt(); + } + + stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} + +// Computes the total cost of the loop nest rooted at 'forOp'. +// Currently, the total cost is computed by counting the total operation +// instance count (i.e. total number of operations in the loop bodyloop +// operation count * loop trip count) for the entire loop nest. +// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops +// specified in the map when computing the total op instance count. +// NOTEs: 1) This is used to compute the cost of computation slices, which are +// sliced along the iteration dimension, and thus reduce the trip count. +// If 'computeCostMap' is non-null, the total op count for forOps specified +// in the map is increased (not overridden) by adding the op count from the +// map to the existing op count for the for loop. This is done before +// multiplying by the loop's trip count, and is used to model the cost of +// inserting a sliced loop nest of known cost into the loop's body. +// 2) This is also used to compute the cost of fusing a slice of some loop nest +// within another loop. +static int64_t getComputeCostHelper( + Operation *forOp, LoopNestStats &stats, + llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap, + DenseMap<Operation *, int64_t> *computeCostMap) { + // 'opCount' is the total number operations in one iteration of 'forOp' body, + // minus terminator op which is a no-op. + int64_t opCount = stats.opCountMap[forOp] - 1; + if (stats.loopMap.count(forOp) > 0) { + for (auto childForOp : stats.loopMap[forOp]) { + opCount += getComputeCostHelper(childForOp.getOperation(), stats, + tripCountOverrideMap, computeCostMap); + } + } + // Add in additional op instances from slice (if specified in map). + if (computeCostMap != nullptr) { + auto it = computeCostMap->find(forOp); + if (it != computeCostMap->end()) { + opCount += it->second; + } + } + // Override trip count (if specified in map). + int64_t tripCount = stats.tripCountMap[forOp]; + if (tripCountOverrideMap != nullptr) { + auto it = tripCountOverrideMap->find(forOp); + if (it != tripCountOverrideMap->end()) { + tripCount = it->second; + } + } + // Returns the total number of dynamic instances of operations in loop body. + return tripCount * opCount; +} + +// TODO(andydavis,b/126426796): extend this to handle multiple result maps. +static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) { + assert(lbMap.getNumResults() == 1 && "expected single result bound map"); + assert(ubMap.getNumResults() == 1 && "expected single result bound map"); + assert(lbMap.getNumDims() == ubMap.getNumDims()); + assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); + AffineExpr lbExpr(lbMap.getResult(0)); + AffineExpr ubExpr(ubMap.getResult(0)); + auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), + lbMap.getNumSymbols()); + auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>(); + if (!cExpr) + return None; + return cExpr.getValue(); +} + +// Return the number of iterations in the given slice. +static uint64_t getSliceIterationCount( + const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) { + uint64_t iterCount = 1; + for (const auto &count : sliceTripCountMap) { + iterCount *= count.second; + } + return iterCount; +} + +// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop +// nest surrounding represented by slice loop bounds in 'slice'. +// Returns true on success, false otherwise (if a non-constant trip count +// was encountered). +// TODO(andydavis) Make this work with non-unit step loops. +static bool buildSliceTripCountMap( + ComputationSliceState *slice, + llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) { + unsigned numSrcLoopIVs = slice->ivs.size(); + // Populate map from AffineForOp -> trip count + for (unsigned i = 0; i < numSrcLoopIVs; ++i) { + AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]); + auto *op = forOp.getOperation(); + AffineMap lbMap = slice->lbs[i]; + AffineMap ubMap = slice->ubs[i]; + if (lbMap == AffineMap() || ubMap == AffineMap()) { + // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. + if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) { + (*tripCountMap)[op] = + forOp.getConstantUpperBound() - forOp.getConstantLowerBound(); + continue; + } + Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); + if (maybeConstTripCount.hasValue()) { + (*tripCountMap)[op] = maybeConstTripCount.getValue(); + continue; + } + return false; + } + Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap); + // Slice bounds are created with a constant ub - lb difference. + if (!tripCount.hasValue()) + return false; + (*tripCountMap)[op] = tripCount.getValue(); + } + return true; +} + +/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. +/// Currently, the total cost is computed by counting the total operation +/// instance count (i.e. total number of operations in the loop body * loop +/// trip count) for the entire loop nest. +int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { + return getComputeCostHelper(forOp.getOperation(), stats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); +} + +/// Computes and returns in 'computeCost', the total compute cost of fusing the +/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, +/// the total cost is computed by counting the total operation instance count +/// (i.e. total number of operations in the loop body * loop trip count) for +/// the entire loop nest. +bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, + AffineForOp dstForOp, LoopNestStats &dstStats, + ComputationSliceState *slice, + int64_t *computeCost) { + llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; + DenseMap<Operation *, int64_t> computeCostMap; + + // Build trip count map for computation slice. + if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) + return false; + // Checks whether a store to load forwarding will happen. + int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); + assert(sliceIterationCount > 0); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); + auto *insertPointParent = slice->insertPoint->getParentOp(); + + // The store and loads to this memref will disappear. + // TODO(andydavis) Add load coalescing to memref data flow opt pass. + if (storeLoadFwdGuaranteed) { + // Subtract from operation count the loads/store we expect load/store + // forwarding to remove. + unsigned storeCount = 0; + llvm::SmallDenseSet<Value, 4> storeMemrefs; + srcForOp.walk([&](Operation *op) { + if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { + storeMemrefs.insert(storeOp.getMemRef()); + ++storeCount; + } + }); + // Subtract out any store ops in single-iteration src slice loop nest. + if (storeCount > 0) + computeCostMap[insertPointParent] = -storeCount; + // Subtract out any load users of 'storeMemrefs' nested below + // 'insertPointParent'. + for (auto value : storeMemrefs) { + for (auto *user : value->getUsers()) { + if (auto loadOp = dyn_cast<AffineLoadOp>(user)) { + SmallVector<AffineForOp, 4> loops; + // Check if any loop in loop nest surrounding 'user' is + // 'insertPointParent'. + getLoopIVs(*user, &loops); + if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) { + if (auto forOp = + dyn_cast_or_null<AffineForOp>(user->getParentOp())) { + if (computeCostMap.count(forOp) == 0) + computeCostMap[forOp] = 0; + computeCostMap[forOp] -= 1; + } + } + } + } + } + } + + // Compute op instance count for the src loop nest with iteration slicing. + int64_t sliceComputeCost = getComputeCostHelper( + srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); + + // Compute cost of fusion for this depth. + computeCostMap[insertPointParent] = sliceComputeCost; + + *computeCost = + getComputeCostHelper(dstForOp.getOperation(), dstStats, + /*tripCountOverrideMap=*/nullptr, &computeCostMap); + return true; +} |