summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/Utils.h11
-rw-r--r--mlir/lib/Analysis/AffineStructures.cpp11
-rw-r--r--mlir/lib/Analysis/Utils.cpp107
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp3
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp245
-rw-r--r--mlir/test/Transforms/loop-fusion.mlir29
6 files changed, 328 insertions, 78 deletions
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index bfdf4d40b34..8c8f73da409 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -76,9 +76,10 @@ struct MemRefRegion {
void setWrite(bool flag) { write = flag; }
/// Returns a constant upper bound on the number of elements in this region if
- /// bounded by a known constant, None otherwise. The 'shape' vector is set to
- /// the corresponding dimension-wise bounds major to minor. We use int64_t
- /// instead of uint64_t since index types can be at most int64_t.
+ /// bounded by a known constant (always possible for static shapes), None
+ /// otherwise. The 'shape' vector is set to the corresponding dimension-wise
+ /// bounds major to minor. We use int64_t instead of uint64_t since index
+ /// types can be at most int64_t.
Optional<int64_t> getConstantBoundingSizeAndShape(
SmallVectorImpl<int> *shape = nullptr,
std::vector<SmallVector<int64_t, 4>> *lbs = nullptr,
@@ -192,6 +193,10 @@ ForInst *insertBackwardComputationSlice(OperationInst *srcOpInst,
OperationInst *dstOpInst,
unsigned dstLoopDepth,
ComputationSliceState *sliceState);
+
+Optional<int64_t> getMemoryFootprintBytes(const ForInst &forInst,
+ int memorySpace = -1);
+
} // end namespace mlir
#endif // MLIR_ANALYSIS_UTILS_H
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 44daaf1459b..baab283ab25 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -395,7 +395,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
FlatAffineConstraints cst;
if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) {
LLVM_DEBUG(llvm::dbgs()
- << "composition unimplemented for semi-affine maps");
+ << "composition unimplemented for semi-affine maps\n");
return false;
}
assert(flatExprs.size() == vMap->getNumResults());
@@ -823,6 +823,9 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
if (posStart >= posLimit)
return 0;
+ LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", "
+ << posLimit << ")\n");
+
GCDTightenInequalities();
unsigned pivotCol = 0;
@@ -1749,6 +1752,9 @@ getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
return {newNumDims, newNumSymbols};
}
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "fm"
+
/// Eliminates identifier at the specified position using Fourier-Motzkin
/// variable elimination. This technique is exact for rational spaces but
/// conservative (in "rare" cases) for integer spaces. The operation corresponds
@@ -1951,6 +1957,9 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
LLVM_DEBUG(dump());
}
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "affine-structures"
+
void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
if (num == 0)
return;
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 592fad4ab29..79d1b696612 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -60,7 +60,8 @@ Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
SmallVectorImpl<int64_t> *lbDivisors) const {
auto memRefType = memref->getType().cast<MemRefType>();
unsigned rank = memRefType.getRank();
- shape->reserve(rank);
+ if (shape)
+ shape->reserve(rank);
// Find a constant upper bound on the extent of this memref region along each
// dimension.
@@ -189,6 +190,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
// Add access function equalities to connect loop IVs to data dimensions.
if (!regionCst->composeMap(&accessValueMap)) {
LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n");
+ LLVM_DEBUG(accessValueMap.getAffineMap().dump());
return false;
}
@@ -207,14 +209,13 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
}
// Project out any local variables (these would have been added for any
// mod/divs).
- regionCst->projectOut(regionCst->getNumDimIds() +
- regionCst->getNumSymbolIds(),
+ regionCst->projectOut(regionCst->getNumDimAndSymbolIds(),
regionCst->getNumLocalIds());
// Set all identifiers appearing after the first 'rank' identifiers as
// symbolic identifiers - so that the ones correspoding to the memref
// dimensions are the dimensional identifiers for the memref region.
- regionCst->setDimSymbolSeparation(regionCst->getNumIds() - rank);
+ regionCst->setDimSymbolSeparation(regionCst->getNumDimAndSymbolIds() - rank);
// Constant fold any symbolic identifiers.
regionCst->constantFoldIdRange(/*pos=*/regionCst->getNumDimIds(),
@@ -222,12 +223,31 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
assert(regionCst->getNumDimIds() == rank && "unexpected MemRefRegion format");
+ LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
+ LLVM_DEBUG(region->getConstraints()->dump());
+
return true;
}
+// TODO(mlir-team): improve/complete this when we have target data.
+static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+ auto elementType = memRefType.getElementType();
+
+ unsigned sizeInBits;
+ if (elementType.isIntOrFloat()) {
+ sizeInBits = elementType.getIntOrFloatBitWidth();
+ } else {
+ auto vectorType = elementType.cast<VectorType>();
+ sizeInBits =
+ vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+ }
+ return llvm::divideCeil(sizeInBits, 8);
+}
+
/// Returns the size of memref data in bytes if it's statically shaped, None
/// otherwise. If the element of the memref has vector type, takes into account
/// size of the vector as well.
+// TODO(mlir-team): improve/complete this when we have target data.
Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
if (memRefType.getNumDynamicDims() > 0)
return None;
@@ -235,18 +255,11 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
return None;
- uint64_t sizeInBits;
- if (elementType.isIntOrFloat()) {
- sizeInBits = elementType.getIntOrFloatBitWidth();
- } else {
- auto vectorType = elementType.cast<VectorType>();
- sizeInBits =
- vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
- }
+ unsigned sizeInBytes = getMemRefEltSizeInBytes(memRefType);
for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
- sizeInBits = sizeInBits * memRefType.getDimSize(i);
+ sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
}
- return llvm::divideCeil(sizeInBits, 8);
+ return sizeInBytes;
}
template <typename LoadOrStoreOpPointer>
@@ -525,3 +538,69 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A,
}
return numCommonLoops;
}
+
+// Returns the size of the region.
+static Optional<int64_t> getRegionSize(const MemRefRegion &region) {
+ auto *memref = region.memref;
+ auto memRefType = memref->getType().cast<MemRefType>();
+
+ auto layoutMaps = memRefType.getAffineMaps();
+ if (layoutMaps.size() > 1 ||
+ (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+ LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ return false;
+ }
+
+ // Indices to use for the DmaStart op.
+ // Indices for the original memref being DMAed from/to.
+ SmallVector<Value *, 4> memIndices;
+ // Indices for the faster buffer being DMAed into/from.
+ SmallVector<Value *, 4> bufIndices;
+
+ // Compute the extents of the buffer.
+ Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape();
+ if (!numElements.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+ return None;
+ }
+ return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
+}
+
+Optional<int64_t> mlir::getMemoryFootprintBytes(const ForInst &forInst,
+ int memorySpace) {
+ std::vector<std::unique_ptr<MemRefRegion>> regions;
+
+ // Walk this 'for' instruction to gather all memory regions.
+ bool error = false;
+ const_cast<ForInst *>(&forInst)->walkOps([&](OperationInst *opInst) {
+ if (!opInst->isa<LoadOp>() && !opInst->isa<StoreOp>()) {
+ // Neither load nor a store op.
+ return;
+ }
+
+ // TODO(bondhugula): eventually, we need to be performing a union across
+ // all regions for a given memref instead of creating one region per
+ // memory op. This way we would be allocating O(num of memref's) sets
+ // instead of O(num of load/store op's).
+ auto region = std::make_unique<MemRefRegion>();
+ if (!getMemRefRegion(opInst, 0, region.get())) {
+ LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
+ // TODO: stop the walk if an error occurred.
+ error = true;
+ return;
+ }
+ regions.push_back(std::move(region));
+ });
+
+ if (error)
+ return None;
+
+ int64_t totalSizeInBytes = 0;
+ for (const auto &region : regions) {
+ auto size = getRegionSize(*region);
+ if (!size.hasValue())
+ return None;
+ totalSizeInBytes += size.getValue();
+ }
+ return totalSizeInBytes;
+}
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 3b829fc55e5..8b86056c8a9 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -179,18 +179,15 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
&fastBufferShape, &lbs, &lbDivisors);
if (!numElements.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
- *sizeInBytes = 0;
return false;
}
if (numElements.getValue() == 0) {
LLVM_DEBUG(llvm::dbgs() << "Nothing to DMA\n");
- *sizeInBytes = 0;
return false;
}
const FlatAffineConstraints *cst = region.getConstraints();
-
// 'outerIVs' holds the values that this memory region is symbolic/paramteric
// on; this would correspond to loop IVs surrounding the level at which the
// DMA generation is being done.
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 520b89ded48..239915b1d4b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -39,6 +39,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include <iomanip>
#define DEBUG_TYPE "loop-fusion"
@@ -46,9 +47,16 @@ using llvm::SetVector;
using namespace mlir;
+/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
- llvm::cl::desc("Enables maximal loop fusion."));
+ llvm::cl::desc("Enables maximal loop fusion"));
+
+/// A threshold in percent of additional computation allowed when fusing.
+static llvm::cl::opt<double> clFusionAddlComputeTolerance(
+ "fusion-compute-tolerance", llvm::cl::Hidden,
+ llvm::cl::desc("Fractional increase in additional "
+ "computation tolerated while fusing"));
namespace {
@@ -66,6 +74,10 @@ struct LoopFusion : public FunctionPass {
PassResult runOnFunction(Function *f) override;
static char passID;
+
+ // The amount of additional computation that is tolerated while fusing
+ // pair-wise as a fraction of the total computation.
+ constexpr static double kComputeToleranceThreshold = 0.30f;
};
} // end anonymous namespace
@@ -496,12 +508,12 @@ public:
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
-static uint64_t getComputeCost(
+static int64_t getComputeCost(
ForInst *forInst, LoopNestStats *stats,
llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
- DenseMap<ForInst *, uint64_t> *computeCostMap) {
+ DenseMap<ForInst *, int64_t> *computeCostMap) {
// 'opCount' is the total number operations in one iteration of 'forInst' body
- uint64_t opCount = stats->opCountMap[forInst];
+ int64_t opCount = stats->opCountMap[forInst];
if (stats->loopMap.count(forInst) > 0) {
for (auto *childForInst : stats->loopMap[forInst]) {
opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
@@ -516,7 +528,7 @@ static uint64_t getComputeCost(
}
}
// Override trip count (if specified in map).
- uint64_t tripCount = stats->tripCountMap[forInst];
+ int64_t tripCount = stats->tripCountMap[forInst];
if (tripCountOverrideMap != nullptr) {
auto it = tripCountOverrideMap->find(forInst);
if (it != tripCountOverrideMap->end()) {
@@ -777,6 +789,16 @@ static Value *createPrivateMemRef(ForInst *forInst,
return newMemRef;
}
+// Returns the total iteration count of the slice (the product of the slice's
+// loop trip counts).
+static uint64_t getSliceIterationCount(
+ const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
+ uint64_t iterCount = 1;
+ for (const auto &count : sliceTripCountMap) {
+ iterCount *= count.second;
+ }
+ return iterCount;
+}
+
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it profitable to fuse the candidate loop nests. Returns
@@ -810,6 +832,14 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
ArrayRef<OperationInst *> dstOpInsts,
ComputationSliceState *sliceState,
unsigned *dstLoopDepth) {
+ LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n";
+ llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n";
+ for (auto dstOpInst
+ : dstOpInsts) {
+ llvm::dbgs() << " ";
+ dstOpInst->dump();
+ });
+
// Compute cost of sliced and unsliced src loop nest.
SmallVector<ForInst *, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
@@ -845,13 +875,27 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
// of these bounds). Next the union slice bounds are used to calculate
// the cost of the slice and the cost of the slice inserted into the dst
// loop nest at 'dstLoopDepth'.
- unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
- unsigned bestDstLoopDepth;
+ uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
+ uint64_t maxStorageReduction = 0;
+ Optional<uint64_t> sliceMemEstimate = None;
+
SmallVector<ComputationSliceState, 4> sliceStates;
sliceStates.resize(maxDstLoopDepth);
+ // The best loop depth at which to materialize the slice.
+ Optional<unsigned> bestDstLoopDepth = None;
+
+ // Compute op instance count for the src loop nest without iteration slicing.
+ uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
+
+ // Compute op instance count for the dst loop nest.
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
- DenseMap<ForInst *, uint64_t> computeCostMap;
+ DenseMap<ForInst *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
MemRefAccess srcAccess(srcOpInst);
// Handle the common case of one dst load without a copy.
@@ -872,56 +916,167 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
sliceTripCountMap.clear();
if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
&sliceTripCountMap))
- return false;
+ // We'll skip cases where the trip count was non-constant.
+ continue;
- // Compute op instance count for the src loop nest with iteration slicing.
- uint64_t sliceComputeCost =
- getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
- /*computeCostMap=*/nullptr);
+ // Checks whether a store to load forwarding will happen.
+ int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
+ bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
+
+ assert(sliceIterationCount > 0);
+
+ // Compute cost of fusion for this dest loop depth.
- // Compute cost of fusion for these values of 'i' and 'j'.
computeCostMap.clear();
+
+ // The store and loads to this memref will disappear.
+ if (storeLoadFwdGuaranteed) {
+ // A single store disappears: -1 for that.
+ computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
+ for (auto *loadOp : dstOpInsts) {
+ if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
+ computeCostMap[loadLoop] = -1;
+ }
+ }
+
+ // Compute op instance count for the src loop nest with iteration slicing.
+ int64_t sliceComputeCost =
+ getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
+ /*tripCountOverrideMap=*/&sliceTripCountMap,
+ /*computeCostMap=*/&computeCostMap);
+
+ // Compute cost of fusion for this depth.
computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
- uint64_t fusedLoopNestComputeCost =
+
+ int64_t fusedLoopNestComputeCost =
getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
- if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
- minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+
+ double additionalComputeFraction =
+ fusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1;
+
+ // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
+ // good way to calculate the footprint of the memref in the slice and
+ // divide it by the total memory footprint of the fused computation.
+ double storageReduction =
+ static_cast<double>(srcLoopNestCost) / sliceIterationCount;
+
+ LLVM_DEBUG(
+ std::stringstream msg;
+ msg << " evaluating fusion profitability at depth : " << i << "\n"
+ << std::setprecision(2) << " additional compute fraction: "
+ << 100.0 * additionalComputeFraction << "%\n"
+ << " storage reduction factor: " << storageReduction << "x\n"
+ << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
+ << " slice iteration count: " << sliceIterationCount << "\n";
+ llvm::dbgs() << msg.str());
+
+ double computeToleranceThreshold =
+ clFusionAddlComputeTolerance.getNumOccurrences() > 0
+ ? clFusionAddlComputeTolerance
+ : LoopFusion::kComputeToleranceThreshold;
+
+ // TODO(b/123247369): This is a placeholder cost model.
+ // Among all choices that add an acceptable amount of redundant computation
+ // (as per computeToleranceThreshold), we will simply pick the one that
+ // reduces the intermediary size the most.
+ if ((storageReduction > maxStorageReduction) &&
+ (clMaximalLoopFusion ||
+ (additionalComputeFraction < computeToleranceThreshold))) {
+ maxStorageReduction = storageReduction;
bestDstLoopDepth = i;
+ minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+ // TODO(bondhugula,andydavis): find a good way to compute the memory
+ // footprint of the materialized slice.
+ // Approximating this to the compute cost of the slice. This could be an
+ // under-approximation or an overapproximation, but in many cases
+ // accurate.
+ sliceMemEstimate = sliceIterationCount;
}
}
- // Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
- // Compute op instance count for the src loop nest.
- uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ // A simple cost model: fuse if it reduces the memory footprint. If
+ // -maximal-fusion is set, fuse nevertheless.
- LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
- << " bestDstLoopDepth: " << bestDstLoopDepth
- << " srcLoopNestCost: " << srcLoopNestCost
- << " dstLoopNestCost: " << dstLoopNestCost
- << " minFusedLoopNestComputeCost: "
- << minFusedLoopNestComputeCost << "\n");
-
- // Do not fuse if fused loop would increase the total cost of the computation,
- // unless 'clMaximalLoopFusion' flag is set.
- // TODO(andydavis) Use locality/reduction in slice memref size/opportunity
- // for load/store forwarding in cost model.
- if (!clMaximalLoopFusion &&
- minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
+ if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "All fusion choices involve more than the threshold amount of "
+ "redundant computation; NOT fusing.\n");
return false;
+ }
+
+ assert(bestDstLoopDepth.hasValue() &&
+ "expected to have a value per logic above");
+
+ // Set dstLoopDepth based on best values from search.
+ *dstLoopDepth = bestDstLoopDepth.getValue();
+
+ LLVM_DEBUG(
+ llvm::dbgs() << " LoopFusion fusion stats:\n"
+ << "\n Best loop depth: " << bestDstLoopDepth
+ << "\n src loop nest compute cost: " << srcLoopNestCost
+ << "\n dst loop nest compute cost: " << dstLoopNestCost
+ << "\n fused loop nest compute cost: "
+ << minFusedLoopNestComputeCost << "\n");
+
+ auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
+ auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);
+
+ Optional<double> storageReduction = None;
+
+ if (!clMaximalLoopFusion) {
+ if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ return false;
+ }
+
+ auto srcMemSizeVal = srcMemSize.getValue();
+ auto dstMemSizeVal = dstMemSize.getValue();
+
+ assert(sliceMemEstimate.hasValue() && "expected value");
+ // This is an inaccurate estimate since sliceMemEstimate is itself an
+ // approximation.
+ auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
+
+ LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
+ << " dst mem: " << dstMemSizeVal << "\n"
+ << " fused mem: " << fusedMem << "\n"
+ << " slice mem: " << sliceMemEstimate << "\n");
+
+ if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ return false;
+ }
+ storageReduction =
+ 100.0 *
+ (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
+ }
+
+ double additionalComputeFraction =
+ 100.0 * (minFusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1);
+
+ std::stringstream msg;
+ msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
+ << setprecision(2) << additionalComputeFraction
+ << "% redundant computation and a ";
+ msg << (storageReduction.hasValue()
+ ? std::to_string(storageReduction.getValue())
+ : "<unknown>");
+ msg << "% storage reduction.\n";
+ LLVM_DEBUG(llvm::dbgs() << msg.str());
+
// Update return parameter 'sliceState' with 'bestSliceState'.
- ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
+ ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
sliceState->lbs = bestSliceState->lbs;
sliceState->ubs = bestSliceState->ubs;
sliceState->lbOperands = bestSliceState->lbOperands;
sliceState->ubOperands = bestSliceState->ubOperands;
- // Set dstLoopDepth based on best values from search.
- *dstLoopDepth = bestDstLoopDepth;
+
// Canonicalize slice bound affine maps.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
if (sliceState->lbs[i] != AffineMap::Null()) {
@@ -1017,29 +1172,35 @@ public:
// Skip 'srcEdge' if not for 'memref'.
if (srcEdge.memref != memref)
continue;
+
auto *srcNode = mdg->getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
if (!isa<ForInst>(srcNode->inst))
continue;
+
// Skip if 'srcNode' has more than one store to 'memref'.
if (srcNode->getStoreOpCount(memref) != 1)
continue;
+
// Skip 'srcNode' if it has in dependence edges. NOTE: This is overly
// TODO(andydavis) Track dependence type with edges, and just check
// for WAW dependence edge here.
if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
continue;
+
// Skip if 'srcNode' has out edges to other memrefs after 'dstId'.
if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId)
continue;
+
+ // Check if fusion would be profitable.
// Get unique 'srcNode' store op.
auto *srcStoreOpInst = srcNode->stores.front();
- // Check if fusion would be profitable.
unsigned dstLoopDepth;
mlir::ComputationSliceState sliceState;
if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
&dstLoopDepth))
continue;
+
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 8e5b706835e..57b5d8dd0ef 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -633,7 +633,6 @@ func @fuse_reshape_16_4_64() {
// -----
-// TODO(b/123072438) Re-enable test MemRefRegion bug is fixed.
// All three loop nests below (6-d one, 2-d one, 2-d one is fused into a single
// 2-d loop nest).
// CHECK-LABEL: func @R6_to_R2_reshape
@@ -970,24 +969,24 @@ func @should_fuse_deep_loop_nests() {
// CHECK-NEXT: for %i1 = 0 to 3 {
// CHECK-NEXT: for %i2 = 0 to 2 {
// CHECK-NEXT: for %i3 = 0 to 2 {
-// CHECK-NEXT: for %i4 = 0 to 16 {
-// CHECK-NEXT: for %i5 = 0 to 10 {
-// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i4, %i5] : memref<2x2x3x3x16x10xf32, 2>
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: for %i6 = 0 to 16 {
-// CHECK-NEXT: for %i7 = 0 to 10 {
-// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i6, %i7)
-// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2>
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: for %i8 = 0 to 3 {
-// CHECK-NEXT: for %i9 = 0 to 3 {
+// CHECK-NEXT: for %i4 = 0 to 3 {
+// CHECK-NEXT: for %i5 = 0 to 3 {
+// CHECK-NEXT: for %i6 = 0 to 16 {
+// CHECK-NEXT: for %i7 = 0 to 10 {
+// CHECK-NEXT: %3 = load %0[%i2, %i3, %i0, %i1, %i6, %i7] : memref<2x2x3x3x16x10xf32, 2>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: for %i8 = 0 to 16 {
+// CHECK-NEXT: for %i9 = 0 to 10 {
+// CHECK-NEXT: %4 = affine_apply #map0(%i2, %i3, %i0, %i1, %i2, %i3, %i0, %i1, %i8, %i9)
+// CHECK-NEXT: store %cst, %2[%4#0, %4#1, %4#2, %4#3, %4#4, %4#5] : memref<1x1x1x1x16x10xf32, 2>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
// CHECK-NEXT: for %i10 = 0 to 2 {
// CHECK-NEXT: for %i11 = 0 to 2 {
// CHECK-NEXT: for %i12 = 0 to 16 {
// CHECK-NEXT: for %i13 = 0 to 10 {
-// CHECK-NEXT: %5 = load %0[%i10, %i11, %i8, %i9, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2>
+// CHECK-NEXT: %5 = load %0[%i10, %i11, %i4, %i5, %i12, %i13] : memref<2x2x3x3x16x10xf32, 2>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: for %i14 = 0 to 16 {
OpenPOWER on IntegriCloud