diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 245 |
1 files changed, 203 insertions, 42 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 520b89ded48..239915b1d4b 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include <iomanip> #define DEBUG_TYPE "loop-fusion" @@ -46,9 +47,16 @@ using llvm::SetVector; using namespace mlir; +/// Disables fusion profitability check and fuses if valid. static llvm::cl::opt<bool> clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden, - llvm::cl::desc("Enables maximal loop fusion.")); + llvm::cl::desc("Enables maximal loop fusion")); + +/// A threshold in percent of additional computation allowed when fusing. +static llvm::cl::opt<double> clFusionAddlComputeTolerance( + "fusion-compute-tolerance", llvm::cl::Hidden, + llvm::cl::desc("Fractional increase in additional" + "computation tolerated while fusing")); namespace { @@ -66,6 +74,10 @@ struct LoopFusion : public FunctionPass { PassResult runOnFunction(Function *f) override; static char passID; + + // The amount of additional computation that is tolerated while fusing + // pair-wise as a fraction of the total computation. + constexpr static double kComputeToleranceThreshold = 0.30f; }; } // end anonymous namespace @@ -496,12 +508,12 @@ public: // inserting a sliced loop nest of known cost into the loop's body. // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. -static uint64_t getComputeCost( +static int64_t getComputeCost( ForInst *forInst, LoopNestStats *stats, llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap, - DenseMap<ForInst *, uint64_t> *computeCostMap) { + DenseMap<ForInst *, int64_t> *computeCostMap) { // 'opCount' is the total number operations in one iteration of 'forInst' body - uint64_t opCount = stats->opCountMap[forInst]; + int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { for (auto *childForInst : stats->loopMap[forInst]) { opCount += getComputeCost(childForInst, stats, tripCountOverrideMap, @@ -516,7 +528,7 @@ static uint64_t getComputeCost( } } // Override trip count (if specified in map). - uint64_t tripCount = stats->tripCountMap[forInst]; + int64_t tripCount = stats->tripCountMap[forInst]; if (tripCountOverrideMap != nullptr) { auto it = tripCountOverrideMap->find(forInst); if (it != tripCountOverrideMap->end()) { @@ -777,6 +789,16 @@ static Value *createPrivateMemRef(ForInst *forInst, return newMemRef; } +// Does the slice have a single iteration? +static uint64_t getSliceIterationCount( + const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) { + uint64_t iterCount = 1; + for (const auto &count : sliceTripCountMap) { + iterCount *= count.second; + } + return iterCount; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. // Returns true if it profitable to fuse the candidate loop nests. Returns @@ -810,6 +832,14 @@ static bool isFusionProfitable(OperationInst *srcOpInst, ArrayRef<OperationInst *> dstOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { + LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n"; + llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n"; + for (auto dstOpInst + : dstOpInsts) { + llvm::dbgs() << " "; + dstOpInst->dump(); + }); + // Compute cost of sliced and unsliced src loop nest. SmallVector<ForInst *, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -845,13 +875,27 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // of these bounds). Next the union slice bounds are used to calculate // the cost of the slice and the cost of the slice inserted into the dst // loop nest at 'dstLoopDepth'. - unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max(); - unsigned bestDstLoopDepth; + uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); + uint64_t maxStorageReduction = 0; + Optional<uint64_t> sliceMemEstimate = None; + SmallVector<ComputationSliceState, 4> sliceStates; sliceStates.resize(maxDstLoopDepth); + // The best loop depth at which to materialize the slice. + Optional<unsigned> bestDstLoopDepth = None; + + // Compute op instance count for the src loop nest without iteration slicing. + uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); + + // Compute op instance count for the src loop nest. + uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap; - DenseMap<ForInst *, uint64_t> computeCostMap; + DenseMap<ForInst *, int64_t> computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. @@ -872,56 +916,167 @@ static bool isFusionProfitable(OperationInst *srcOpInst, sliceTripCountMap.clear(); if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], &sliceTripCountMap)) - return false; + // We'll skip cases where we the trip count was non-constant. + continue; - // Compute op instance count for the src loop nest with iteration slicing. - uint64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap, - /*computeCostMap=*/nullptr); + // Checks whether a store to load forwarding will happen. + int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); + + assert(sliceIterationCount > 0); + + // Compute cost of fusion for this dest loop depth. - // Compute cost of fusion for these values of 'i' and 'j'. computeCostMap.clear(); + + // The store and loads to this memref will disappear. + if (storeLoadFwdGuaranteed) { + // A single store disappears: -1 for that. + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; + for (auto *loadOp : dstOpInsts) { + if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst())) + computeCostMap[loadLoop] = -1; + } + } + + // Compute op instance count for the src loop nest with iteration slicing. + int64_t sliceComputeCost = + getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/&sliceTripCountMap, + /*computeCostMap=*/&computeCostMap); + + // Compute cost of fusion for this depth. computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost; - uint64_t fusedLoopNestComputeCost = + + int64_t fusedLoopNestComputeCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); - if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) { - minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + + double additionalComputeFraction = + fusedLoopNestComputeCost / + (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - + 1; + + // TODO(bondhugula): This is an ugly approximation. Fix this by finding a + // good way to calculate the footprint of the memref in the slice and + // divide it by the total memory footprint of the fused computation. + double storageReduction = + static_cast<double>(srcLoopNestCost) / sliceIterationCount; + + LLVM_DEBUG( + std::stringstream msg; + msg << " evaluating fusion profitability at depth : " << i << "\n" + << std::setprecision(2) << " additional compute fraction: " + << 100.0 * additionalComputeFraction << "%\n" + << " storage reduction factor: " << storageReduction << "x\n" + << " fused nest cost: " << fusedLoopNestComputeCost << "\n" + << " slice iteration count: " << sliceIterationCount << "\n"; + llvm::dbgs() << msg.str()); + + double computeToleranceThreshold = + clFusionAddlComputeTolerance.getNumOccurrences() > 0 + ? clFusionAddlComputeTolerance + : LoopFusion::kComputeToleranceThreshold; + + // TODO(b/123247369): This is a placeholder cost model. + // Among all choices that add an acceptable amount of redundant computation + // (as per computeToleranceThreshold), we will simply pick the one that + // reduces the intermediary size the most. + if ((storageReduction > maxStorageReduction) && + (clMaximalLoopFusion || + (additionalComputeFraction < computeToleranceThreshold))) { + maxStorageReduction = storageReduction; bestDstLoopDepth = i; + minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + // TODO(bondhugula,andydavis): find a good way to compute the memory + // footprint of the materialized slice. + // Approximating this to the compute cost of the slice. This could be an + // under-approximation or an overapproximation, but in many cases + // accurate. + sliceMemEstimate = sliceIterationCount; } } - // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); - // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + // A simple cost model: fuse if it reduces the memory footprint. If + // -maximal-fusion is set, fuse nevertheless. - LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics " - << " bestDstLoopDepth: " << bestDstLoopDepth - << " srcLoopNestCost: " << srcLoopNestCost - << " dstLoopNestCost: " << dstLoopNestCost - << " minFusedLoopNestComputeCost: " - << minFusedLoopNestComputeCost << "\n"); - - // Do not fuse if fused loop would increase the total cost of the computation, - // unless 'clMaximalLoopFusion' flag is set. - // TODO(andydavis) Use locality/reduction in slice memref size/opportunity - // for load/store forwarding in cost model. - if (!clMaximalLoopFusion && - minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) + if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << "All fusion choices involve more than the threshold amount of" + "redundant computation; NOT fusing.\n"); return false; + } + + assert(bestDstLoopDepth.hasValue() && + "expected to have a value per logic above"); + + // Set dstLoopDepth based on best values from search. + *dstLoopDepth = bestDstLoopDepth.getValue(); + + LLVM_DEBUG( + llvm::dbgs() << " LoopFusion fusion stats:\n" + << "\n Best loop depth: " << bestDstLoopDepth + << "\n src loop nest compute cost: " << srcLoopNestCost + << "\n dst loop nest compute cost: " << dstLoopNestCost + << "\n fused loop nest compute cost: " + << minFusedLoopNestComputeCost << "\n"); + + auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]); + + Optional<double> storageReduction = None; + + if (!clMaximalLoopFusion) { + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG( + llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } + + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); + + assert(sliceMemEstimate.hasValue() && "expected value"); + // This is an inaccurate estimate since sliceMemEstimate is isaccurate. + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + << " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); + + if (fusedMem > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; + } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); + } + + double additionalComputeFraction = + 100.0 * (minFusedLoopNestComputeCost / + (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - + 1); + + std::stringstream msg; + msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " + << setprecision(2) << additionalComputeFraction + << "% redundant computation and a "; + msg << (storageReduction.hasValue() + ? std::to_string(storageReduction.getValue()) + : "<unknown>"); + msg << "% storage reduction.\n"; + LLVM_DEBUG(llvm::dbgs() << msg.str()); + // Update return parameter 'sliceState' with 'bestSliceState'. - ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1]; + ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; sliceState->lbs = bestSliceState->lbs; sliceState->ubs = bestSliceState->ubs; sliceState->lbOperands = bestSliceState->lbOperands; sliceState->ubOperands = bestSliceState->ubOperands; - // Set dstLoopDepth based on best values from search. - *dstLoopDepth = bestDstLoopDepth; + // Canonicalize slice bound affine maps. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { if (sliceState->lbs[i] != AffineMap::Null()) { @@ -1017,29 +1172,35 @@ public: // Skip 'srcEdge' if not for 'memref'. if (srcEdge.memref != memref) continue; + auto *srcNode = mdg->getNode(srcEdge.id); // Skip if 'srcNode' is not a loop nest. if (!isa<ForInst>(srcNode->inst)) continue; + // Skip if 'srcNode' has more than one store to 'memref'. if (srcNode->getStoreOpCount(memref) != 1) continue; + // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly // TODO(andydavis) Track dependence type with edges, and just check // for WAW dependence edge here. if (mdg->getInEdgeCount(srcNode->id, memref) != 0) continue; + // Skip if 'srcNode' has out edges to other memrefs after 'dstId'. if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId) continue; + + // Check if fusion would be profitable. // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); - // Check if fusion would be profitable. unsigned dstLoopDepth; mlir::ComputationSliceState sliceState; if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, &dstLoopDepth)) continue; + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto *sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState); |

