diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Analysis/Utils.h | 11 | ||||
| -rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 11 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 107 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DmaGeneration.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 245 | ||||
| -rw-r--r-- | mlir/test/Transforms/loop-fusion.mlir | 29 |
6 files changed, 328 insertions, 78 deletions
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index bfdf4d40b34..8c8f73da409 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -76,9 +76,10 @@ struct MemRefRegion { void setWrite(bool flag) { write = flag; } /// Returns a constant upper bound on the number of elements in this region if - /// bounded by a known constant, None otherwise. The 'shape' vector is set to - /// the corresponding dimension-wise bounds major to minor. We use int64_t - /// instead of uint64_t since index types can be at most int64_t. + /// bounded by a known constant (always possible for static shapes), None + /// otherwise. The 'shape' vector is set to the corresponding dimension-wise + /// bounds major to minor. We use int64_t instead of uint64_t since index + /// types can be at most int64_t. Optional<int64_t> getConstantBoundingSizeAndShape( SmallVectorImpl<int> *shape = nullptr, std::vector<SmallVector<int64_t, 4>> *lbs = nullptr, @@ -192,6 +193,10 @@ ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState); + +Optional<int64_t> getMemoryFootprintBytes(const ForInst &forInst, + int memorySpace = -1); + } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 44daaf1459b..baab283ab25 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -395,7 +395,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { FlatAffineConstraints cst; if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { LLVM_DEBUG(llvm::dbgs() - << "composition unimplemented for semi-affine maps"); + << "composition unimplemented for semi-affine maps\n"); return false; } assert(flatExprs.size() == vMap->getNumResults()); @@ -823,6 +823,9 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, if (posStart >= posLimit) return 0; + LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", " + << posLimit << ")\n"); + GCDTightenInequalities(); unsigned pivotCol = 0; @@ -1749,6 +1752,9 @@ getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { return {newNumDims, newNumSymbols}; } +#undef DEBUG_TYPE +#define DEBUG_TYPE "fm" + /// Eliminates identifier at the specified position using Fourier-Motzkin /// variable elimination. This technique is exact for rational spaces but /// conservative (in "rare" cases) for integer spaces. The operation corresponds @@ -1951,6 +1957,9 @@ void FlatAffineConstraints::FourierMotzkinEliminate( LLVM_DEBUG(dump()); } +#undef DEBUG_TYPE +#define DEBUG_TYPE "affine-structures" + void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { if (num == 0) return; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 592fad4ab29..79d1b696612 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -60,7 +60,8 @@ Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape( SmallVectorImpl<int64_t> *lbDivisors) const { auto memRefType = memref->getType().cast<MemRefType>(); unsigned rank = memRefType.getRank(); - shape->reserve(rank); + if (shape) + shape->reserve(rank); // Find a constant upper bound on the extent of this memref region along each // dimension. @@ -189,6 +190,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // Add access function equalities to connect loop IVs to data dimensions. if (!regionCst->composeMap(&accessValueMap)) { LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n"); + LLVM_DEBUG(accessValueMap.getAffineMap().dump()); return false; } @@ -207,14 +209,13 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, } // Project out any local variables (these would have been added for any // mod/divs). - regionCst->projectOut(regionCst->getNumDimIds() + - regionCst->getNumSymbolIds(), + regionCst->projectOut(regionCst->getNumDimAndSymbolIds(), regionCst->getNumLocalIds()); // Set all identifiers appearing after the first 'rank' identifiers as // symbolic identifiers - so that the ones correspoding to the memref // dimensions are the dimensional identifiers for the memref region. - regionCst->setDimSymbolSeparation(regionCst->getNumIds() - rank); + regionCst->setDimSymbolSeparation(regionCst->getNumDimAndSymbolIds() - rank); // Constant fold any symbolic identifiers. regionCst->constantFoldIdRange(/*pos=*/regionCst->getNumDimIds(), @@ -222,12 +223,31 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, assert(regionCst->getNumDimIds() == rank && "unexpected MemRefRegion format"); + LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); + LLVM_DEBUG(region->getConstraints()->dump()); + return true; } +// TODO(mlir-team): improve/complete this when we have target data. +static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast<VectorType>(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + /// Returns the size of memref data in bytes if it's statically shaped, None /// otherwise. If the element of the memref has vector type, takes into account /// size of the vector as well. +// TODO(mlir-team): improve/complete this when we have target data. Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) { if (memRefType.getNumDynamicDims() > 0) return None; @@ -235,18 +255,11 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) { if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) return None; - uint64_t sizeInBits; - if (elementType.isIntOrFloat()) { - sizeInBits = elementType.getIntOrFloatBitWidth(); - } else { - auto vectorType = elementType.cast<VectorType>(); - sizeInBits = - vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); - } + unsigned sizeInBytes = getMemRefEltSizeInBytes(memRefType); for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) { - sizeInBits = sizeInBits * memRefType.getDimSize(i); + sizeInBytes = sizeInBytes * memRefType.getDimSize(i); } - return llvm::divideCeil(sizeInBits, 8); + return sizeInBytes; } template <typename LoadOrStoreOpPointer> @@ -525,3 +538,69 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, } return numCommonLoops; } + +// Returns the size of the region. +static Optional<int64_t> getRegionSize(const MemRefRegion ®ion) { + auto *memref = region.memref; + auto memRefType = memref->getType().cast<MemRefType>(); + + auto layoutMaps = memRefType.getAffineMaps(); + if (layoutMaps.size() > 1 || + (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + return false; + } + + // Indices to use for the DmaStart op. + // Indices for the original memref being DMAed from/to. + SmallVector<Value *, 4> memIndices; + // Indices for the faster buffer being DMAed into/from. + SmallVector<Value *, 4> bufIndices; + + // Compute the extents of the buffer. + Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(); + if (!numElements.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); + return None; + } + return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); +} + +Optional<int64_t> mlir::getMemoryFootprintBytes(const ForInst &forInst, + int memorySpace) { + std::vector<std::unique_ptr<MemRefRegion>> regions; + + // Walk this 'for' instruction to gather all memory regions. + bool error = false; + const_cast<ForInst *>(&forInst)->walkOps([&](OperationInst *opInst) { + if (!opInst->isa<LoadOp>() && !opInst->isa<StoreOp>()) { + // Neither load nor a store op. + return; + } + + // TODO(bondhugula): eventually, we need to be performing a union across + // all regions for a given memref instead of creating one region per + // memory op. This way we would be allocating O(num of memref's) sets + // instead of O(num of load/store op's). + auto region = std::make_unique<MemRefRegion>(); + if (!getMemRefRegion(opInst, 0, region.get())) { + LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n"); + // TODO: stop the walk if an error occurred. + error = true; + return; + } + regions.push_back(std::move(region)); + }); + + if (error) + return None; + + int64_t totalSizeInBytes = 0; + for (const auto ®ion : regions) { + auto size = getRegionSize(*region); + if (!size.hasValue()) + return None; + totalSizeInBytes += size.getValue(); + } + return totalSizeInBytes; +} diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 3b829fc55e5..8b86056c8a9 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -179,18 +179,15 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForInst *forInst, &fastBufferShape, &lbs, &lbDivisors); if (!numElements.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); - *sizeInBytes = 0; return false; } if (numElements.getValue() == 0) { LLVM_DEBUG(llvm::dbgs() << "Nothing to DMA\n"); - *sizeInBytes = 0; return false; } const FlatAffineConstraints *cst = region.getConstraints(); - // 'outerIVs' holds the values that this memory region is symbolic/paramteric // on; this would correspond to loop IVs surrounding the level at which the // DMA generation is being done. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 520b89ded48..239915b1d4b 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include <iomanip> #define DEBUG_TYPE "loop-fusion" @@ -46,9 +47,16 @@ using llvm::SetVector; using namespace mlir; +/// Disables fusion profitability check and fuses if valid. static llvm::cl::opt<bool> clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden, - llvm::cl::desc("Enables maximal loop fusion.")); + llvm::cl::desc("Enables maximal loop fusion")); + +/// A threshold in percent of additional computation allowed when fusing. +static llvm::cl::opt<double> clFusionAddlComputeTolerance( + "fusion-compute-tolerance", llvm::cl::Hidden, + llvm::cl::desc("Fractional increase in additional" + "computation tolerated while fusing")); namespace { @@ -66,6 +74,10 @@ struct LoopFusion : public FunctionPass { PassResult runOnFunction(Function *f) override; static char passID; + + // The amount of additional computation that is tolerated while fusing + // pair-wise as a fraction of the total computation. + constexpr static double kComputeToleranceThreshold = 0.30f; }; } // end anonymous namespace @@ -496,12 +508,12 @@ public: // inserting a sliced loop nest of known cost into the loop's body. // NOTE: this is used to compute the cost of fusing a slice of some loop nest // within another loop. -static uint64_t getComputeCost( +static int64_t getComputeCost( ForInst *forInst, LoopNestStats *stats, llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap, - DenseMap<ForInst *, uint64_t> *computeCostMap) { + DenseMap<ForInst *, int64_t> *computeCostMap) { // 'opCount' is the total number operations in one iteration of 'forInst' body - uint64_t opCount = stats->opCountMap[forInst]; + int64_t opCount = stats->opCountMap[forInst]; if (stats->loopMap.count(forInst) > 0) { for (auto *childForInst : stats->loopMap[forInst]) { opCount += getComputeCost(childForInst, stats, tripCountOverrideMap, @@ -516,7 +528,7 @@ static uint64_t getComputeCost( } } // Override trip count (if specified in map). - uint64_t tripCount = stats->tripCountMap[forInst]; + int64_t tripCount = stats->tripCountMap[forInst]; if (tripCountOverrideMap != nullptr) { auto it = tripCountOverrideMap->find(forInst); if (it != tripCountOverrideMap->end()) { @@ -777,6 +789,16 @@ static Value *createPrivateMemRef(ForInst *forInst, return newMemRef; } +// Does the slice have a single iteration? +static uint64_t getSliceIterationCount( + const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) { + uint64_t iterCount = 1; + for (const auto &count : sliceTripCountMap) { + iterCount *= count.second; + } + return iterCount; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'. // Returns true if it profitable to fuse the candidate loop nests. Returns @@ -810,6 +832,14 @@ static bool isFusionProfitable(OperationInst *srcOpInst, ArrayRef<OperationInst *> dstOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { + LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n"; + llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n"; + for (auto dstOpInst + : dstOpInsts) { + llvm::dbgs() << " "; + dstOpInst->dump(); + }); + // Compute cost of sliced and unsliced src loop nest. SmallVector<ForInst *, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -845,13 +875,27 @@ static bool isFusionProfitable(OperationInst *srcOpInst, // of these bounds). Next the union slice bounds are used to calculate // the cost of the slice and the cost of the slice inserted into the dst // loop nest at 'dstLoopDepth'. - unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max(); - unsigned bestDstLoopDepth; + uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); + uint64_t maxStorageReduction = 0; + Optional<uint64_t> sliceMemEstimate = None; + SmallVector<ComputationSliceState, 4> sliceStates; sliceStates.resize(maxDstLoopDepth); + // The best loop depth at which to materialize the slice. + Optional<unsigned> bestDstLoopDepth = None; + + // Compute op instance count for the src loop nest without iteration slicing. + uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); + + // Compute op instance count for the src loop nest. + uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, + /*tripCountOverrideMap=*/nullptr, + /*computeCostMap=*/nullptr); llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap; - DenseMap<ForInst *, uint64_t> computeCostMap; + DenseMap<ForInst *, int64_t> computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { MemRefAccess srcAccess(srcOpInst); // Handle the common case of one dst load without a copy. @@ -872,56 +916,167 @@ static bool isFusionProfitable(OperationInst *srcOpInst, sliceTripCountMap.clear(); if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1], &sliceTripCountMap)) - return false; + // We'll skip cases where we the trip count was non-constant. + continue; - // Compute op instance count for the src loop nest with iteration slicing. - uint64_t sliceComputeCost = - getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap, - /*computeCostMap=*/nullptr); + // Checks whether a store to load forwarding will happen. + int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); + bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); + + assert(sliceIterationCount > 0); + + // Compute cost of fusion for this dest loop depth. - // Compute cost of fusion for these values of 'i' and 'j'. computeCostMap.clear(); + + // The store and loads to this memref will disappear. + if (storeLoadFwdGuaranteed) { + // A single store disappears: -1 for that. + computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1; + for (auto *loadOp : dstOpInsts) { + if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst())) + computeCostMap[loadLoop] = -1; + } + } + + // Compute op instance count for the src loop nest with iteration slicing. + int64_t sliceComputeCost = + getComputeCost(srcLoopIVs[0], &srcLoopNestStats, + /*tripCountOverrideMap=*/&sliceTripCountMap, + /*computeCostMap=*/&computeCostMap); + + // Compute cost of fusion for this depth. computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost; - uint64_t fusedLoopNestComputeCost = + + int64_t fusedLoopNestComputeCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); - if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) { - minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + + double additionalComputeFraction = + fusedLoopNestComputeCost / + (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - + 1; + + // TODO(bondhugula): This is an ugly approximation. Fix this by finding a + // good way to calculate the footprint of the memref in the slice and + // divide it by the total memory footprint of the fused computation. + double storageReduction = + static_cast<double>(srcLoopNestCost) / sliceIterationCount; + + LLVM_DEBUG( + std::stringstream msg; + msg << " evaluating fusion profitability at depth : " << i << "\n" + << std::setprecision(2) << " additional compute fraction: " + << 100.0 * additionalComputeFraction << "%\n" + << " storage reduction factor: " << storageReduction << "x\n" + << " fused nest cost: " << fusedLoopNestComputeCost << "\n" + << " slice iteration count: " << sliceIterationCount << "\n"; + llvm::dbgs() << msg.str()); + + double computeToleranceThreshold = + clFusionAddlComputeTolerance.getNumOccurrences() > 0 + ? clFusionAddlComputeTolerance + : LoopFusion::kComputeToleranceThreshold; + + // TODO(b/123247369): This is a placeholder cost model. + // Among all choices that add an acceptable amount of redundant computation + // (as per computeToleranceThreshold), we will simply pick the one that + // reduces the intermediary size the most. + if ((storageReduction > maxStorageReduction) && + (clMaximalLoopFusion || + (additionalComputeFraction < computeToleranceThreshold))) { + maxStorageReduction = storageReduction; bestDstLoopDepth = i; + minFusedLoopNestComputeCost = fusedLoopNestComputeCost; + // TODO(bondhugula,andydavis): find a good way to compute the memory + // footprint of the materialized slice. + // Approximating this to the compute cost of the slice. This could be an + // under-approximation or an overapproximation, but in many cases + // accurate. + sliceMemEstimate = sliceIterationCount; } } - // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); - // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats, - /*tripCountOverrideMap=*/nullptr, - /*computeCostMap=*/nullptr); + // A simple cost model: fuse if it reduces the memory footprint. If + // -maximal-fusion is set, fuse nevertheless. - LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics " - << " bestDstLoopDepth: " << bestDstLoopDepth - << " srcLoopNestCost: " << srcLoopNestCost - << " dstLoopNestCost: " << dstLoopNestCost - << " minFusedLoopNestComputeCost: " - << minFusedLoopNestComputeCost << "\n"); - - // Do not fuse if fused loop would increase the total cost of the computation, - // unless 'clMaximalLoopFusion' flag is set. - // TODO(andydavis) Use locality/reduction in slice memref size/opportunity - // for load/store forwarding in cost model. - if (!clMaximalLoopFusion && - minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost) + if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << "All fusion choices involve more than the threshold amount of" + "redundant computation; NOT fusing.\n"); return false; + } + + assert(bestDstLoopDepth.hasValue() && + "expected to have a value per logic above"); + + // Set dstLoopDepth based on best values from search. + *dstLoopDepth = bestDstLoopDepth.getValue(); + + LLVM_DEBUG( + llvm::dbgs() << " LoopFusion fusion stats:\n" + << "\n Best loop depth: " << bestDstLoopDepth + << "\n src loop nest compute cost: " << srcLoopNestCost + << "\n dst loop nest compute cost: " << dstLoopNestCost + << "\n fused loop nest compute cost: " + << minFusedLoopNestComputeCost << "\n"); + + auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]); + + Optional<double> storageReduction = None; + + if (!clMaximalLoopFusion) { + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG( + llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } + + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); + + assert(sliceMemEstimate.hasValue() && "expected value"); + // This is an inaccurate estimate since sliceMemEstimate is isaccurate. + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + << " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); + + if (fusedMem > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; + } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); + } + + double additionalComputeFraction = + 100.0 * (minFusedLoopNestComputeCost / + (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - + 1); + + std::stringstream msg; + msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " + << setprecision(2) << additionalComputeFraction + << "% redundant computation and a "; + msg << (storageReduction.hasValue() + ? std::to_string(storageReduction.getValue()) + : "<unknown>"); + msg << "% storage reduction.\n"; + LLVM_DEBUG(llvm::dbgs() << msg.str()); + // Update return parameter 'sliceState' with 'bestSliceState'. - ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1]; + ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; sliceState->lbs = bestSliceState->lbs; sliceState->ubs = bestSliceState->ubs; sliceState->lbOperands = bestSliceState->lbOperands; sliceState->ubOperands = bestSliceState->ubOperands; - // Set dstLoopDepth based on best values from search. - *dstLoopDepth = bestDstLoopDepth; + // Canonicalize slice bound affine maps. for (unsigned i = 0; i < numSrcLoopIVs; ++i) { if (sliceState->lbs[i] != AffineMap::Null()) { @@ -1017,29 +1172,35 @@ public: // Skip 'srcEdge' if not for 'memref'. if (srcEdge.memref != memref) continue; + auto *srcNode = mdg->getNode(srcEdge.id); // Skip if 'srcNode' is not a loop nest. if (!isa<ForInst>(srcNode->inst)) continue; + // Skip if 'srcNode' has more than one store to 'memref'. if (srcNode->getStoreOpCount(memref) != 1) continue; + // Skip 'srcNode' if it has in dependence edges. NOTE: This is overly // TODO(andydavis) Track dependence type with edges, and just check // for WAW dependence edge here. if (mdg->getInEdgeCount(srcNode->id, memref) != 0) continue; + // Skip if 'srcNode' has out edges to other memrefs after 'dstId'. if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId) continue; + + // Check if fusion would be profitable. // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); - // Check if fusion would be profitable. unsigned dstLoopDepth; mlir::ComputationSliceState sliceState; if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState, &dstLoopDepth)) continue; + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. auto *sliceLoopNest = mlir::insertBackwardComputationSlice( srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 8e5b706835e..57b5d8dd0ef 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -633,7 +633,6 @@ func @fuse_reshape_16_4_64() { // ----- -// TODO(b/123072438) Re-enable test MemRefRegion bug is fixed. // All three loop nests below (6-d one, 2-d one, 2-d one is fused into a single // 2-d loop nest). // CHECK-LABEL: func @R6_to_R2_reshape @@ -970,24 +969,24 @@ func @should_fuse_deep_loop_nests() { // CHECK-NEXT: for %i1 = 0 to 3 { // CHECK-NEXT: for %i2 = 0 to 2 { // CHECK-NEXT: for %i3 = 0 to 2 { -// CHECK-NEXT: for %i4 = 0 to 16 { -// CHECK-NEXT: for %i5 = 0 to 10 { -// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: for %i6 = 0 to 16 { -// CHECK-NEXT: for %i7 = 0 to 10 { -// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i6, %i7) -// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2> -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: for %i8 = 0 to 3 { -// CHECK-NEXT: for %i9 = 0 to 3 { +// CHECK-NEXT: for %i4 = 0 to 3 { +// CHECK-NEXT: for %i5 = 0 to 3 { +// CHECK-NEXT: for %i6 = 0 to 16 { +// CHECK-NEXT: for %i7 = 0 to 10 { +// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: for %i8 = 0 to 16 { +// CHECK-NEXT: for %i9 = 0 to 10 { +// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9) +// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2> +// CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: for %i10 = 0 to 2 { // CHECK-NEXT: for %i11 = 0 to 2 { // CHECK-NEXT: for %i12 = 0 to 16 { // CHECK-NEXT: for %i13 = 0 to 10 { -// CHECK-NEXT: %5 = load %0[%i10, %i11, %i8, %i9, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> +// CHECK-NEXT: %5 = load %0[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: for %i14 = 0 to 16 { |

