summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/LoopFusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp245
1 files changed, 203 insertions, 42 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 520b89ded48..239915b1d4b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -39,6 +39,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include <iomanip>
#define DEBUG_TYPE "loop-fusion"
@@ -46,9 +47,16 @@ using llvm::SetVector;
using namespace mlir;
+/// Disables fusion profitability check and fuses if valid.
static llvm::cl::opt<bool>
clMaximalLoopFusion("fusion-maximal", llvm::cl::Hidden,
- llvm::cl::desc("Enables maximal loop fusion."));
+ llvm::cl::desc("Enables maximal loop fusion"));
+
+/// A threshold in percent of additional computation allowed when fusing.
+static llvm::cl::opt<double> clFusionAddlComputeTolerance(
+ "fusion-compute-tolerance", llvm::cl::Hidden,
+ llvm::cl::desc("Fractional increase in additional"
+ "computation tolerated while fusing"));
namespace {
@@ -66,6 +74,10 @@ struct LoopFusion : public FunctionPass {
PassResult runOnFunction(Function *f) override;
static char passID;
+
+ // The amount of additional computation that is tolerated while fusing
+ // pair-wise as a fraction of the total computation.
+ constexpr static double kComputeToleranceThreshold = 0.30f;
};
} // end anonymous namespace
@@ -496,12 +508,12 @@ public:
// inserting a sliced loop nest of known cost into the loop's body.
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
-static uint64_t getComputeCost(
+static int64_t getComputeCost(
ForInst *forInst, LoopNestStats *stats,
llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
- DenseMap<ForInst *, uint64_t> *computeCostMap) {
+ DenseMap<ForInst *, int64_t> *computeCostMap) {
// 'opCount' is the total number operations in one iteration of 'forInst' body
- uint64_t opCount = stats->opCountMap[forInst];
+ int64_t opCount = stats->opCountMap[forInst];
if (stats->loopMap.count(forInst) > 0) {
for (auto *childForInst : stats->loopMap[forInst]) {
opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
@@ -516,7 +528,7 @@ static uint64_t getComputeCost(
}
}
// Override trip count (if specified in map).
- uint64_t tripCount = stats->tripCountMap[forInst];
+ int64_t tripCount = stats->tripCountMap[forInst];
if (tripCountOverrideMap != nullptr) {
auto it = tripCountOverrideMap->find(forInst);
if (it != tripCountOverrideMap->end()) {
@@ -777,6 +789,16 @@ static Value *createPrivateMemRef(ForInst *forInst,
return newMemRef;
}
+// Does the slice have a single iteration?
+static uint64_t getSliceIterationCount(
+ const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
+ uint64_t iterCount = 1;
+ for (const auto &count : sliceTripCountMap) {
+ iterCount *= count.second;
+ }
+ return iterCount;
+}
+
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstOpInsts'.
// Returns true if it profitable to fuse the candidate loop nests. Returns
@@ -810,6 +832,14 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
ArrayRef<OperationInst *> dstOpInsts,
ComputationSliceState *sliceState,
unsigned *dstLoopDepth) {
+ LLVM_DEBUG(llvm::dbgs() << "Checking whether fusion is profitable between:\n";
+ llvm::dbgs() << " "; srcOpInst->dump(); llvm::dbgs() << " and \n";
+ for (auto dstOpInst
+ : dstOpInsts) {
+ llvm::dbgs() << " ";
+ dstOpInst->dump();
+ });
+
// Compute cost of sliced and unsliced src loop nest.
SmallVector<ForInst *, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
@@ -845,13 +875,27 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
// of these bounds). Next the union slice bounds are used to calculate
// the cost of the slice and the cost of the slice inserted into the dst
// loop nest at 'dstLoopDepth'.
- unsigned minFusedLoopNestComputeCost = std::numeric_limits<unsigned>::max();
- unsigned bestDstLoopDepth;
+ uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
+ uint64_t maxStorageReduction = 0;
+ Optional<uint64_t> sliceMemEstimate = None;
+
SmallVector<ComputationSliceState, 4> sliceStates;
sliceStates.resize(maxDstLoopDepth);
+ // The best loop depth at which to materialize the slice.
+ Optional<unsigned> bestDstLoopDepth = None;
+
+ // Compute op instance count for the src loop nest without iteration slicing.
+ uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
+
+ // Compute op instance count for the src loop nest.
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
- DenseMap<ForInst *, uint64_t> computeCostMap;
+ DenseMap<ForInst *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
MemRefAccess srcAccess(srcOpInst);
// Handle the common case of one dst load without a copy.
@@ -872,56 +916,167 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
sliceTripCountMap.clear();
if (!buildSliceTripCountMap(srcOpInst, &sliceStates[i - 1],
&sliceTripCountMap))
- return false;
+ // We'll skip cases where we the trip count was non-constant.
+ continue;
- // Compute op instance count for the src loop nest with iteration slicing.
- uint64_t sliceComputeCost =
- getComputeCost(srcLoopIVs[0], &srcLoopNestStats, &sliceTripCountMap,
- /*computeCostMap=*/nullptr);
+ // Checks whether a store to load forwarding will happen.
+ int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
+ bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
+
+ assert(sliceIterationCount > 0);
+
+ // Compute cost of fusion for this dest loop depth.
- // Compute cost of fusion for these values of 'i' and 'j'.
computeCostMap.clear();
+
+ // The store and loads to this memref will disappear.
+ if (storeLoadFwdGuaranteed) {
+ // A single store disappears: -1 for that.
+ computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
+ for (auto *loadOp : dstOpInsts) {
+ if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
+ computeCostMap[loadLoop] = -1;
+ }
+ }
+
+ // Compute op instance count for the src loop nest with iteration slicing.
+ int64_t sliceComputeCost =
+ getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
+ /*tripCountOverrideMap=*/&sliceTripCountMap,
+ /*computeCostMap=*/&computeCostMap);
+
+ // Compute cost of fusion for this depth.
computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
- uint64_t fusedLoopNestComputeCost =
+
+ int64_t fusedLoopNestComputeCost =
getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
- if (fusedLoopNestComputeCost < minFusedLoopNestComputeCost) {
- minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+
+ double additionalComputeFraction =
+ fusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1;
+
+ // TODO(bondhugula): This is an ugly approximation. Fix this by finding a
+ // good way to calculate the footprint of the memref in the slice and
+ // divide it by the total memory footprint of the fused computation.
+ double storageReduction =
+ static_cast<double>(srcLoopNestCost) / sliceIterationCount;
+
+ LLVM_DEBUG(
+ std::stringstream msg;
+ msg << " evaluating fusion profitability at depth : " << i << "\n"
+ << std::setprecision(2) << " additional compute fraction: "
+ << 100.0 * additionalComputeFraction << "%\n"
+ << " storage reduction factor: " << storageReduction << "x\n"
+ << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
+ << " slice iteration count: " << sliceIterationCount << "\n";
+ llvm::dbgs() << msg.str());
+
+ double computeToleranceThreshold =
+ clFusionAddlComputeTolerance.getNumOccurrences() > 0
+ ? clFusionAddlComputeTolerance
+ : LoopFusion::kComputeToleranceThreshold;
+
+ // TODO(b/123247369): This is a placeholder cost model.
+ // Among all choices that add an acceptable amount of redundant computation
+ // (as per computeToleranceThreshold), we will simply pick the one that
+ // reduces the intermediary size the most.
+ if ((storageReduction > maxStorageReduction) &&
+ (clMaximalLoopFusion ||
+ (additionalComputeFraction < computeToleranceThreshold))) {
+ maxStorageReduction = storageReduction;
bestDstLoopDepth = i;
+ minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
+ // TODO(bondhugula,andydavis): find a good way to compute the memory
+ // footprint of the materialized slice.
+ // Approximating this to the compute cost of the slice. This could be an
+ // under-approximation or an overapproximation, but in many cases
+ // accurate.
+ sliceMemEstimate = sliceIterationCount;
}
}
- // Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
- // Compute op instance count for the src loop nest.
- uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ // A simple cost model: fuse if it reduces the memory footprint. If
+ // -maximal-fusion is set, fuse nevertheless.
- LLVM_DEBUG(llvm::dbgs() << "LoopFusion statistics "
- << " bestDstLoopDepth: " << bestDstLoopDepth
- << " srcLoopNestCost: " << srcLoopNestCost
- << " dstLoopNestCost: " << dstLoopNestCost
- << " minFusedLoopNestComputeCost: "
- << minFusedLoopNestComputeCost << "\n");
-
- // Do not fuse if fused loop would increase the total cost of the computation,
- // unless 'clMaximalLoopFusion' flag is set.
- // TODO(andydavis) Use locality/reduction in slice memref size/opportunity
- // for load/store forwarding in cost model.
- if (!clMaximalLoopFusion &&
- minFusedLoopNestComputeCost > srcLoopNestCost + dstLoopNestCost)
+ if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "All fusion choices involve more than the threshold amount of"
+ "redundant computation; NOT fusing.\n");
return false;
+ }
+
+ assert(bestDstLoopDepth.hasValue() &&
+ "expected to have a value per logic above");
+
+ // Set dstLoopDepth based on best values from search.
+ *dstLoopDepth = bestDstLoopDepth.getValue();
+
+ LLVM_DEBUG(
+ llvm::dbgs() << " LoopFusion fusion stats:\n"
+ << "\n Best loop depth: " << bestDstLoopDepth
+ << "\n src loop nest compute cost: " << srcLoopNestCost
+ << "\n dst loop nest compute cost: " << dstLoopNestCost
+ << "\n fused loop nest compute cost: "
+ << minFusedLoopNestComputeCost << "\n");
+
+ auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
+ auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);
+
+ Optional<double> storageReduction = None;
+
+ if (!clMaximalLoopFusion) {
+ if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ return false;
+ }
+
+ auto srcMemSizeVal = srcMemSize.getValue();
+ auto dstMemSizeVal = dstMemSize.getValue();
+
+ assert(sliceMemEstimate.hasValue() && "expected value");
+ // This is an inaccurate estimate since sliceMemEstimate is isaccurate.
+ auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
+
+ LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
+ << " dst mem: " << dstMemSizeVal << "\n"
+ << " fused mem: " << fusedMem << "\n"
+ << " slice mem: " << sliceMemEstimate << "\n");
+
+ if (fusedMem > srcMemSizeVal + dstMemSizeVal) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ return false;
+ }
+ storageReduction =
+ 100.0 *
+ (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
+ }
+
+ double additionalComputeFraction =
+ 100.0 * (minFusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1);
+
+ std::stringstream msg;
+ msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
+ << setprecision(2) << additionalComputeFraction
+ << "% redundant computation and a ";
+ msg << (storageReduction.hasValue()
+ ? std::to_string(storageReduction.getValue())
+ : "<unknown>");
+ msg << "% storage reduction.\n";
+ LLVM_DEBUG(llvm::dbgs() << msg.str());
+
// Update return parameter 'sliceState' with 'bestSliceState'.
- ComputationSliceState *bestSliceState = &sliceStates[bestDstLoopDepth - 1];
+ ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
sliceState->lbs = bestSliceState->lbs;
sliceState->ubs = bestSliceState->ubs;
sliceState->lbOperands = bestSliceState->lbOperands;
sliceState->ubOperands = bestSliceState->ubOperands;
- // Set dstLoopDepth based on best values from search.
- *dstLoopDepth = bestDstLoopDepth;
+
// Canonicalize slice bound affine maps.
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
if (sliceState->lbs[i] != AffineMap::Null()) {
@@ -1017,29 +1172,35 @@ public:
// Skip 'srcEdge' if not for 'memref'.
if (srcEdge.memref != memref)
continue;
+
auto *srcNode = mdg->getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
if (!isa<ForInst>(srcNode->inst))
continue;
+
// Skip if 'srcNode' has more than one store to 'memref'.
if (srcNode->getStoreOpCount(memref) != 1)
continue;
+
// Skip 'srcNode' if it has in dependence edges. NOTE: This is overly
// TODO(andydavis) Track dependence type with edges, and just check
// for WAW dependence edge here.
if (mdg->getInEdgeCount(srcNode->id, memref) != 0)
continue;
+
// Skip if 'srcNode' has out edges to other memrefs after 'dstId'.
if (mdg->getMinOutEdgeNodeId(srcNode->id, memref) < dstId)
continue;
+
+ // Check if fusion would be profitable.
// Get unique 'srcNode' store op.
auto *srcStoreOpInst = srcNode->stores.front();
- // Check if fusion would be profitable.
unsigned dstLoopDepth;
mlir::ComputationSliceState sliceState;
if (!isFusionProfitable(srcStoreOpInst, dstLoadOpInsts, &sliceState,
&dstLoopDepth))
continue;
+
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], dstLoopDepth, &sliceState);
OpenPOWER on IntegriCloud