summary refs log tree commit diff stats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r-- mlir/lib/Transforms/LoopFusion.cpp 86
-rw-r--r-- mlir/lib/Transforms/TestLoopFusion.cpp 6
-rw-r--r-- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp 57
3 files changed, 60 insertions, 89 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 1f475f1fb44..7eb2c7289c0 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1192,82 +1192,6 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
return true;
}
-// Computes the union of all slice bounds computed between 'srcOpInst'
-// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns
-// the union in 'sliceState'. Returns true on success, false otherwise.
-// TODO(andydavis) Move this to a loop fusion utility function.
-static bool getSliceUnion(Operation *srcOpInst,
- ArrayRef<Operation *> dstLoadOpInsts,
- unsigned numSrcLoopIVs, unsigned dstLoopDepth,
- ComputationSliceState *sliceState) {
- MemRefAccess srcAccess(srcOpInst);
- unsigned numDstLoadOpInsts = dstLoadOpInsts.size();
- assert(numDstLoadOpInsts > 0);
- // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'.
- if (failed(mlir::getBackwardComputationSliceState(
- srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth,
- sliceState)))
- return false;
- // Handle the common case of one dst load without a copy.
- if (numDstLoadOpInsts == 1)
- return true;
-
- // Initialize 'sliceUnionCst' with the bounds computed in previous step.
- FlatAffineConstraints sliceUnionCst;
- if (failed(sliceState->getAsConstraints(&sliceUnionCst))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n.");
- return false;
- }
-
- // Compute the union of slice bounds between 'srcOpInst' and each load
- // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'.
- for (unsigned i = 1; i < numDstLoadOpInsts; ++i) {
- MemRefAccess dstAccess(dstLoadOpInsts[i]);
- // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'.
- ComputationSliceState tmpSliceState;
- if (failed(mlir::getBackwardComputationSliceState(
- srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n.");
- return false;
- }
-
- // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
- FlatAffineConstraints tmpSliceCst;
- if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n.");
- return false;
- }
- // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
- if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute union bounding box of slice bounds.\n.");
- return false;
- }
- }
-
- // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
- sliceUnionCst.convertLoopIVSymbolsToDims();
-
- sliceState->clearBounds();
- sliceState->lbs.resize(numSrcLoopIVs, AffineMap());
- sliceState->ubs.resize(numSrcLoopIVs, AffineMap());
-
- // Get slice bounds from slice union constraints 'sliceUnionCst'.
- sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(),
- &sliceState->lbs, &sliceState->ubs);
- // Add slice bound operands of union.
- SmallVector<Value *, 4> sliceBoundOperands;
- sliceUnionCst.getIdValues(numSrcLoopIVs,
- sliceUnionCst.getNumDimAndSymbolIds(),
- &sliceBoundOperands);
- // Give each bound its own copy of 'sliceBoundOperands' for subsequent
- // canonicalization.
- sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
- sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
- return true;
-}
-
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -1404,10 +1328,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
DenseMap<Operation *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
// Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
- if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i,
- &sliceStates[i - 1])) {
+ if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
+ /*dstLoopDepth=*/i,
+ &sliceStates[i - 1]))) {
LLVM_DEBUG(llvm::dbgs()
- << "getSliceUnion failed for loopDepth: " << i << "\n");
+ << "computeSliceUnion failed for loopDepth: " << i << "\n");
continue;
}
@@ -1813,9 +1738,10 @@ public:
continue;
// TODO(andydavis) Remove assert and surrounding code when
// canFuseLoops is fully functional.
+ mlir::ComputationSliceState sliceUnion;
FusionResult result = mlir::canFuseLoops(
cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
- bestDstLoopDepth, /*srcSlice=*/nullptr);
+ bestDstLoopDepth, &sliceUnion);
assert(result.value == FusionResult::Success);
(void)result;
diff --git a/mlir/lib/Transforms/TestLoopFusion.cpp b/mlir/lib/Transforms/TestLoopFusion.cpp
index 9ace2fb4350..638cf915b6a 100644
--- a/mlir/lib/Transforms/TestLoopFusion.cpp
+++ b/mlir/lib/Transforms/TestLoopFusion.cpp
@@ -76,8 +76,10 @@ static void testDependenceCheck(SmallVector<AffineForOp, 2> &loops, unsigned i,
unsigned j, unsigned loopDepth) {
AffineForOp srcForOp = loops[i];
AffineForOp dstForOp = loops[j];
- FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, loopDepth,
- /*srcSlice=*/nullptr);
+ mlir::ComputationSliceState sliceUnion;
+ // TODO(andydavis) Test at loop depths deeper than current loop depth + 1.
+ FusionResult result =
+ mlir::canFuseLoops(srcForOp, dstForOp, loopDepth + 1, &sliceUnion);
if (result.value == FusionResult::FailBlockDependence) {
srcForOp.getOperation()->emitRemark("block-level dependence preventing"
" fusion of loop nest ")
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 9de6766e075..cb1d9d17ed0 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -40,9 +40,10 @@
using namespace mlir;
-// Gathers all load and store operations in 'opA' into 'values', where
+// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
-static void getLoadsAndStores(Operation *opA, DenseMap<Value *, bool> &values) {
+static void getLoadAndStoreMemRefAccesses(Operation *opA,
+ DenseMap<Value *, bool> &values) {
opA->walk([&](Operation *op) {
if (auto loadOp = dyn_cast<LoadOp>(op)) {
if (values.count(loadOp.getMemRef()) == 0)
@@ -73,7 +74,7 @@ static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
// Record memref values from all loads/store in loop nest rooted at 'opA'.
// Map from memref value to bool which is true if store, false otherwise.
DenseMap<Value *, bool> values;
- getLoadsAndStores(opA, values);
+ getLoadAndStoreMemRefAccesses(opA, values);
// For each 'opX' in block in range ('opA', 'opB'), check if there is a data
// dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
@@ -99,7 +100,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
// Record memref values from all loads/store in loop nest rooted at 'opB'.
// Map from memref value to bool which is true if store, false otherwise.
DenseMap<Value *, bool> values;
- getLoadsAndStores(opB, values);
+ getLoadAndStoreMemRefAccesses(opB, values);
// For each 'opX' in block in range ('opA', 'opB') in reverse order,
// check if there is a data dependence from 'opX' to 'opB':
@@ -176,8 +177,22 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
return forOpB.getOperation();
}
+// Gathers all load and store ops in loop nest rooted at 'forOp' into
+// 'loadAndStoreOps'. Returns false if the walk encounters an AffineIfOp.
+static bool
+gatherLoadsAndStores(AffineForOp forOp,
+ SmallVectorImpl<Operation *> &loadAndStoreOps) {
+ bool hasIfOp = false;
+ forOp.getOperation()->walk([&](Operation *op) {
+ if (isa<LoadOp>(op) || isa<StoreOp>(op))
+ loadAndStoreOps.push_back(op);
+ else if (isa<AffineIfOp>(op))
+ hasIfOp = true;
+ });
+ return !hasIfOp;
+}
+
// TODO(andydavis) Add support for the following features in subsequent CLs:
-// *) Computing union of slices computed between src/dst loads and stores.
// *) Compute dependences of unfused src/dst loops.
// *) Compute dependences of src/dst loop as if they were fused.
// *) Check for fusion preventing dependences (e.g. a dependence which changes
@@ -185,18 +200,46 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
unsigned dstLoopDepth,
ComputationSliceState *srcSlice) {
- // Return 'false' if 'srcForOp' and 'dstForOp' are not in the same block.
+ // Return 'failure' if 'dstLoopDepth == 0'.
+ if (dstLoopDepth == 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
+ return FusionResult::FailPrecondition;
+ }
+ // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
auto *block = srcForOp.getOperation()->getBlock();
if (block != dstForOp.getOperation()->getBlock()) {
LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
return FusionResult::FailPrecondition;
}
- // Return 'false' if no valid insertion point for fused loop nest in 'block'
+ // Return 'failure' if no valid insertion point for fused loop nest in 'block'
// exists which would preserve dependences.
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
return FusionResult::FailBlockDependence;
}
+
+ // Gather all load and store ops in 'srcForOp'.
+ SmallVector<Operation *, 4> srcLoadAndStoreOps;
+ if (!gatherLoadsAndStores(srcForOp, srcLoadAndStoreOps)) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+ return FusionResult::FailPrecondition;
+ }
+
+ // Gather all load and store ops in 'dstForOp'.
+ SmallVector<Operation *, 4> dstLoadAndStoreOps;
+ if (!gatherLoadsAndStores(dstForOp, dstLoadAndStoreOps)) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+ return FusionResult::FailPrecondition;
+ }
+
+ // Compute union of computation slices computed from all pairs in
+ // {'srcLoadAndStoreOps', 'dstLoadAndStoreOps'}.
+ if (failed(mlir::computeSliceUnion(srcLoadAndStoreOps, dstLoadAndStoreOps,
+ dstLoopDepth, srcSlice))) {
+ LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+ return FusionResult::FailPrecondition;
+ }
+
return FusionResult::Success;
}
OpenPOWER on IntegriCloud