Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r--   mlir/lib/Transforms/LoopFusion.cpp   200
1 file changed, 106 insertions(+), 94 deletions(-)
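This change ports LoopFusion.cpp from the dedicated ForInst instruction class to AffineForOp operations held through OpPointer<AffineForOp>, and switches the pass's analysis maps from ForInst* keys to plain Instruction* keys. A minimal sketch of the loop-detection idiom before and after the port, assembled from the calls that appear in the hunks below (the surrounding MLIR API of this era is assumed, not reproduced):

    // Before: a loop was its own instruction kind, found with an LLVM-style cast.
    if (auto *forInst = dyn_cast<ForInst>(&inst))
      forInsts.push_back(forInst);

    // After: a loop is an ordinary OperationInst wrapping an AffineForOp.
    if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>())
      forOps.push_back(forOp);
    // Where a stable map key is needed, the patch uses forOp->getInstruction().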
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index fa0e3b51de3..7d4ff03e306 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -97,15 +97,15 @@ namespace {
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
- SmallVector<ForInst *, 4> forInsts;
+ SmallVector<OpPointer<AffineForOp>, 4> forOps;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
bool hasNonForRegion = false;
- void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
-
void visitOperationInst(OperationInst *opInst) {
- if (opInst->getNumBlockLists() != 0)
+ if (opInst->isa<AffineForOp>())
+ forOps.push_back(opInst->cast<AffineForOp>());
+ else if (opInst->getNumBlockLists() != 0)
hasNonForRegion = true;
else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
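With visitForInst gone, loops, loads, and stores are all recognized inside visitOperationInst, and callers drive the collector with a generic walk. Roughly how the rest of this patch uses the collector (a sketch stitched together from later hunks, not a verbatim excerpt):

    LoopNestStateCollector collector;
    collector.walk(&inst);                   // visit every op in the nest
    if (collector.hasNonForRegion)
      return false;                          // regions other than loops unsupported
    for (auto forOp : collector.forOps)      // OpPointer<AffineForOp> handles
      promoteIfSingleIteration(forOp);       // as done for slice loops below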
@@ -491,14 +491,14 @@ bool MemRefDependenceGraph::init(Function *f) {
if (f->getBlocks().size() != 1)
return false;
- DenseMap<ForInst *, unsigned> forToNodeMap;
+ DenseMap<Instruction *, unsigned> forToNodeMap;
for (auto &inst : f->front()) {
- if (auto *forInst = dyn_cast<ForInst>(&inst)) {
- // Create graph node 'id' to represent top-level 'forInst' and record
+ if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
+ // Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
- collector.walkForInst(forInst);
- // Return false if IfInsts are found (not currently supported).
+ collector.walk(&inst);
+ // Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion)
return false;
Node node(nextNodeId++, &inst);
@@ -512,10 +512,9 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
- forToNodeMap[forInst] = node.id;
+ forToNodeMap[&inst] = node.id;
nodes.insert({node.id, node});
- }
- if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ } else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &inst);
@@ -552,12 +551,12 @@ bool MemRefDependenceGraph::init(Function *f) {
for (auto *value : opInst->getResults()) {
for (auto &use : value->getUses()) {
auto *userOpInst = cast<OperationInst>(use.getOwner());
- SmallVector<ForInst *, 4> loops;
+ SmallVector<OpPointer<AffineForOp>, 4> loops;
getLoopIVs(*userOpInst, &loops);
if (loops.empty())
continue;
- assert(forToNodeMap.count(loops[0]) > 0);
- unsigned userLoopNestId = forToNodeMap[loops[0]];
+ assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
+ unsigned userLoopNestId = forToNodeMap[loops[0]->getInstruction()];
addEdge(node.id, userLoopNestId, value);
}
}
@@ -587,12 +586,12 @@ namespace {
// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
// and operation count) for a loop nest up until the innermost loop body.
struct LoopNestStats {
- // Map from ForInst to immediate child ForInsts in its loop body.
- DenseMap<ForInst *, SmallVector<ForInst *, 2>> loopMap;
- // Map from ForInst to count of operations in its loop body.
- DenseMap<ForInst *, uint64_t> opCountMap;
- // Map from ForInst to its constant trip count.
- DenseMap<ForInst *, uint64_t> tripCountMap;
+ // Map from AffineForOp to immediate child AffineForOps in its loop body.
+ DenseMap<Instruction *, SmallVector<OpPointer<AffineForOp>, 2>> loopMap;
+ // Map from AffineForOp to count of operations in its loop body.
+ DenseMap<Instruction *, uint64_t> opCountMap;
+ // Map from AffineForOp to its constant trip count.
+ DenseMap<Instruction *, uint64_t> tripCountMap;
};
// LoopNestStatsCollector walks a single loop nest and gathers per-loop
@@ -604,23 +603,31 @@ public:
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
- void visitForInst(ForInst *forInst) {
- auto *parentInst = forInst->getParentInst();
+ void visitOperationInst(OperationInst *opInst) {
+ auto forOp = opInst->dyn_cast<AffineForOp>();
+ if (!forOp)
+ return;
+
+ auto *forInst = forOp->getInstruction();
+ auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
- assert(isa<ForInst>(parentInst) && "Expected parent ForInst");
- // Add mapping to 'forInst' from its parent ForInst.
- stats->loopMap[cast<ForInst>(parentInst)].push_back(forInst);
+ assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
+ "Expected parent AffineForOp");
+ // Add mapping to 'forOp' from its parent AffineForOp.
+ stats->loopMap[parentInst].push_back(forOp);
}
- // Record the number of op instructions in the body of 'forInst'.
+
+ // Record the number of op instructions in the body of 'forOp'.
unsigned count = 0;
stats->opCountMap[forInst] = 0;
- for (auto &inst : *forInst->getBody()) {
- if (isa<OperationInst>(&inst))
+ for (auto &inst : *forOp->getBody()) {
+ if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
+ cast<OperationInst>(inst).isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
- // Record trip count for 'forInst'. Set flag if trip count is not constant.
- Optional<uint64_t> maybeConstTripCount = getConstantTripCount(*forInst);
+ // Record trip count for 'forOp'. Set flag if trip count is not constant.
+ Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount.hasValue()) {
hasLoopWithNonConstTripCount = true;
return;
@@ -629,7 +636,7 @@ public:
}
};
-// Computes the total cost of the loop nest rooted at 'forInst'.
+// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body = loop
// operation count * loop trip count) for the entire loop nest.
@@ -637,7 +644,7 @@ public:
// specified in the map when computing the total op instance count.
// NOTE: this is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
-// If 'computeCostMap' is non-null, the total op count for forInsts specified
+// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
@@ -645,15 +652,15 @@ public:
// NOTE: this is used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCost(
- ForInst *forInst, LoopNestStats *stats,
- llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountOverrideMap,
- DenseMap<ForInst *, int64_t> *computeCostMap) {
- // 'opCount' is the total number operations in one iteration of 'forInst' body
+ Instruction *forInst, LoopNestStats *stats,
+ llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountOverrideMap,
+ DenseMap<Instruction *, int64_t> *computeCostMap) {
+ // 'opCount' is the total number of operations in one iteration of 'forOp' body
int64_t opCount = stats->opCountMap[forInst];
if (stats->loopMap.count(forInst) > 0) {
- for (auto *childForInst : stats->loopMap[forInst]) {
- opCount += getComputeCost(childForInst, stats, tripCountOverrideMap,
- computeCostMap);
+ for (auto childForOp : stats->loopMap[forInst]) {
+ opCount += getComputeCost(childForOp->getInstruction(), stats,
+ tripCountOverrideMap, computeCostMap);
}
}
// Add in additional op instances from slice (if specified in map).
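The comment block above describes the cost model: a loop's cost is its own op count, plus its children's costs and any extra slice cost from 'computeCostMap', all multiplied by its (possibly overridden) trip count. A self-contained toy version using plain standard-library containers instead of the MLIR types, assuming that is the intended recurrence (the tail of the function, where the trip-count multiply happens, falls outside this hunk):

    #include <cstdint>
    #include <map>
    #include <vector>

    struct ToyLoop {
      uint64_t opCount;                 // ops directly in this loop's body
      uint64_t tripCount;               // constant trip count
      std::vector<ToyLoop *> children;  // immediately nested loops
    };

    static int64_t toyComputeCost(
        const ToyLoop *loop,
        const std::map<const ToyLoop *, uint64_t> *tripOverride,
        const std::map<const ToyLoop *, int64_t> *extraCost) {
      int64_t ops = loop->opCount;
      for (const ToyLoop *child : loop->children)
        ops += toyComputeCost(child, tripOverride, extraCost);
      if (extraCost && extraCost->count(loop))
        ops += extraCost->at(loop);     // e.g. the cost of a fused slice
      uint64_t trip = loop->tripCount;
      if (tripOverride && tripOverride->count(loop))
        trip = tripOverride->at(loop);  // sliced iteration count
      return ops * static_cast<int64_t>(trip);
    }

    // Example: for i in 0..99 { 2 ops; for j in 0..9 { 3 ops } }
    // toyComputeCost = (2 + 3 * 10) * 100 = 3200 operation instances.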
@@ -694,18 +701,18 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
return cExpr.getValue();
}
-// Builds a map 'tripCountMap' from ForInst to constant trip count for loop
+// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
// nest surrounding 'srcAccess' utilizing slice loop bounds in 'sliceState'.
// Returns true on success, false otherwise (if a non-constant trip count
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
OperationInst *srcOpInst, ComputationSliceState *sliceState,
- llvm::SmallDenseMap<ForInst *, uint64_t, 8> *tripCountMap) {
- SmallVector<ForInst *, 4> srcLoopIVs;
+ llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
+ SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
- // Populate map from ForInst -> trip count
+ // Populate map from AffineForOp -> trip count
for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
AffineMap lbMap = sliceState->lbs[i];
AffineMap ubMap = sliceState->ubs[i];
@@ -713,7 +720,7 @@ static bool buildSliceTripCountMap(
// The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
if (srcLoopIVs[i]->hasConstantLowerBound() &&
srcLoopIVs[i]->hasConstantUpperBound()) {
- (*tripCountMap)[srcLoopIVs[i]] =
+ (*tripCountMap)[srcLoopIVs[i]->getInstruction()] =
srcLoopIVs[i]->getConstantUpperBound() -
srcLoopIVs[i]->getConstantLowerBound();
continue;
@@ -723,7 +730,7 @@ static bool buildSliceTripCountMap(
Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
if (!tripCount.hasValue())
return false;
- (*tripCountMap)[srcLoopIVs[i]] = tripCount.getValue();
+ (*tripCountMap)[srcLoopIVs[i]->getInstruction()] = tripCount.getValue();
}
return true;
}
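With the unit-step assumption the TODO above mentions, each entry written into 'tripCountMap' is simply an upper bound minus a lower bound, taken either from the loop's own constant bounds or from the slice's bound maps. A worked example under that assumption:

    // Unsliced source loop with constant bounds 4 and 132:
    //   trip count = 132 - 4 = 128
    // Sliced loop whose slice bounds are lb = d0 and ub = d0 + 8:
    //   trip count = (d0 + 8) - d0 = 8   (via getConstDifference above)
    // Both results are stored keyed by the loop's underlying Instruction*.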
@@ -750,7 +757,7 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
unsigned numOps = ops.size();
assert(numOps > 0);
- std::vector<SmallVector<ForInst *, 4>> loops(numOps);
+ std::vector<SmallVector<OpPointer<AffineForOp>, 4>> loops(numOps);
unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
for (unsigned i = 0; i < numOps; ++i) {
getLoopIVs(*ops[i], &loops[i]);
@@ -762,9 +769,8 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
for (unsigned d = 0; d < loopDepthLimit; ++d) {
unsigned i;
for (i = 1; i < numOps; ++i) {
- if (loops[i - 1][d] != loops[i][d]) {
+ if (loops[i - 1][d] != loops[i][d])
break;
- }
}
if (i != numOps)
break;
@@ -871,14 +877,16 @@ static bool getSliceUnion(const ComputationSliceState &sliceStateA,
}
// Creates and returns a private (single-user) memref for fused loop rooted
-// at 'forInst', with (potentially reduced) memref size based on the
+// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
-static Value *createPrivateMemRef(ForInst *forInst,
+static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
OperationInst *srcStoreOpInst,
unsigned dstLoopDepth) {
- // Create builder to insert alloc op just before 'forInst'.
+ auto *forInst = forOp->getInstruction();
+
+ // Create builder to insert alloc op just before 'forOp'.
FuncBuilder b(forInst);
// Builder to create constants at the top level.
FuncBuilder top(forInst->getFunction());
@@ -934,16 +942,16 @@ static Value *createPrivateMemRef(ForInst *forInst,
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
allocOperands.push_back(
- top.create<DimOp>(forInst->getLoc(), oldMemRef, dynamicDimCount++));
+ top.create<DimOp>(forOp->getLoc(), oldMemRef, dynamicDimCount++));
}
- // Create new private memref for fused loop 'forInst'.
+ // Create new private memref for fused loop 'forOp'.
// TODO(andydavis) Create/move alloc ops for private memrefs closer to their
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the function, because loop nests can be reordered
// during the fusion pass.
Value *newMemRef =
- top.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
+ top.create<AllocOp>(forOp->getLoc(), newMemRefType, allocOperands);
// Build an AffineMap to remap access functions based on lower bound offsets.
SmallVector<AffineExpr, 4> remapExprs;
@@ -967,7 +975,7 @@ static Value *createPrivateMemRef(ForInst *forInst,
bool ret =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domInstFilter=*/&*forInst->getBody()->begin());
+ /*domInstFilter=*/&*forOp->getBody()->begin());
assert(ret && "replaceAllMemrefUsesWith should always succeed here");
(void)ret;
return newMemRef;
@@ -975,7 +983,7 @@ static Value *createPrivateMemRef(ForInst *forInst,
// Does the slice have a single iteration?
static uint64_t getSliceIterationCount(
- const llvm::SmallDenseMap<ForInst *, uint64_t, 8> &sliceTripCountMap) {
+ const llvm::SmallDenseMap<Instruction *, uint64_t, 8> &sliceTripCountMap) {
uint64_t iterCount = 1;
for (const auto &count : sliceTripCountMap) {
iterCount *= count.second;
@@ -1030,25 +1038,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
});
// Compute cost of sliced and unsliced src loop nest.
- SmallVector<ForInst *, 4> srcLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
- srcStatsCollector.walk(srcLoopIVs[0]);
+ srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (srcStatsCollector.hasLoopWithNonConstTripCount)
return false;
// Compute cost of dst loop nest.
- SmallVector<ForInst *, 4> dstLoopIVs;
+ SmallVector<OpPointer<AffineForOp>, 4> dstLoopIVs;
getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
LoopNestStats dstLoopNestStats;
LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
- dstStatsCollector.walk(dstLoopIVs[0]);
+ dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (dstStatsCollector.hasLoopWithNonConstTripCount)
return false;
@@ -1075,17 +1083,19 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
Optional<unsigned> bestDstLoopDepth = None;
// Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t srcLoopNestCost =
+ getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
// Compute op instance count for the dst loop nest.
- uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
- /*tripCountOverrideMap=*/nullptr,
- /*computeCostMap=*/nullptr);
+ uint64_t dstLoopNestCost =
+ getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
+ /*tripCountOverrideMap=*/nullptr,
+ /*computeCostMap=*/nullptr);
- llvm::SmallDenseMap<ForInst *, uint64_t, 8> sliceTripCountMap;
- DenseMap<ForInst *, int64_t> computeCostMap;
+ llvm::SmallDenseMap<Instruction *, uint64_t, 8> sliceTripCountMap;
+ DenseMap<Instruction *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
MemRefAccess srcAccess(srcOpInst);
// Handle the common case of one dst load without a copy.
@@ -1121,24 +1131,25 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
// The store and loads to this memref will disappear.
if (storeLoadFwdGuaranteed) {
// A single store disappears: -1 for that.
- computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]] = -1;
+ computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
for (auto *loadOp : dstLoadOpInsts) {
- if (auto *loadLoop = dyn_cast_or_null<ForInst>(loadOp->getParentInst()))
- computeCostMap[loadLoop] = -1;
+ auto *parentInst = loadOp->getParentInst();
+ if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
+ computeCostMap[parentInst] = -1;
}
}
// Compute op instance count for the src loop nest with iteration slicing.
int64_t sliceComputeCost =
- getComputeCost(srcLoopIVs[0], &srcLoopNestStats,
+ getComputeCost(srcLoopIVs[0]->getInstruction(), &srcLoopNestStats,
/*tripCountOverrideMap=*/&sliceTripCountMap,
/*computeCostMap=*/&computeCostMap);
// Compute cost of fusion for this depth.
- computeCostMap[dstLoopIVs[i - 1]] = sliceComputeCost;
+ computeCostMap[dstLoopIVs[i - 1]->getInstruction()] = sliceComputeCost;
int64_t fusedLoopNestComputeCost =
- getComputeCost(dstLoopIVs[0], &dstLoopNestStats,
+ getComputeCost(dstLoopIVs[0]->getInstruction(), &dstLoopNestStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
double additionalComputeFraction =
@@ -1211,8 +1222,8 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
<< "\n fused loop nest compute cost: "
<< minFusedLoopNestComputeCost << "\n");
- auto dstMemSize = getMemoryFootprintBytes(*dstLoopIVs[0]);
- auto srcMemSize = getMemoryFootprintBytes(*srcLoopIVs[0]);
+ auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
+ auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Optional<double> storageReduction = None;
@@ -1292,9 +1303,9 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
-// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
-// destination ForInst into which fusion will be attempted.
-// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
+// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
+// candidate destination AffineForOp into which fusion will be attempted.
+// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
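A condensed outline of the greedy driver this comment describes, with the MLIR specifics stripped out (helper names such as findEarlierSingleStoreProducer and fuseSliceInto are hypothetical stand-ins for the steps listed above, not functions in this patch):

    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      Node *dstNode = mdg->getNode(dstId);
      if (!isLoopNest(dstNode))              // only AffineForOp nests are candidates
        continue;
      for (auto *loadOp : dstNode->loads) {
        auto *memref = getLoadMemRef(loadOp);
        Node *srcNode = findEarlierSingleStoreProducer(mdg, memref, dstId);
        if (!srcNode || !profitable(srcNode, dstNode))  // isFusionProfitable() gate
          continue;
        fuseSliceInto(srcNode, dstNode);     // slice insertion, promotion, private memref
        if (mdg->canRemoveNode(srcNode->id))
          mdg->removeNode(srcNode->id);      // erase the source nest once fully consumed
      }
    }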
@@ -1342,7 +1353,7 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
- if (!isa<ForInst>(dstNode->inst))
+ if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
@@ -1375,7 +1386,7 @@ public:
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
auto *srcNode = mdg->getNode(srcId);
// Skip if 'srcNode' is not a loop nest.
- if (!isa<ForInst>(srcNode->inst))
+ if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
continue;
// Skip if 'srcNode' has more than one store to any memref.
// TODO(andydavis) Support fusing multi-output src loop nests.
@@ -1417,25 +1428,26 @@ public:
continue;
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
- auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
+ auto sliceLoopNest = mlir::insertBackwardComputationSlice(
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
- // Move 'dstForInst' before 'insertPointInst' if needed.
- auto *dstForInst = cast<ForInst>(dstNode->inst);
- if (insertPointInst != dstForInst) {
- dstForInst->moveBefore(insertPointInst);
+ // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+ auto dstAffineForOp =
+ cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
+ if (insertPointInst != dstAffineForOp->getInstruction()) {
+ dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
}
// Update edges between 'srcNode' and 'dstNode'.
mdg->updateEdges(srcNode->id, dstNode->id, memref);
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
- sliceCollector.walkForInst(sliceLoopNest);
+ sliceCollector.walk(sliceLoopNest->getInstruction());
// Promote single iteration slice loops to single IV value.
- for (auto *forInst : sliceCollector.forInsts) {
- promoteIfSingleIteration(forInst);
+ for (auto forOp : sliceCollector.forOps) {
+ promoteIfSingleIteration(forOp);
}
- // Create private memref for 'memref' in 'dstForInst'.
+ // Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<OperationInst *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
@@ -1443,7 +1455,7 @@ public:
}
assert(storesForMemref.size() == 1);
auto *newMemRef = createPrivateMemRef(
- dstForInst, storesForMemref[0], bestDstLoopDepth);
+ dstAffineForOp, storesForMemref[0], bestDstLoopDepth);
visitedMemrefs.insert(newMemRef);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId =
@@ -1453,7 +1465,7 @@ public:
// Collect dst loop stats after memref privatization transformation.
LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.walkForInst(dstForInst);
+ dstLoopCollector.walk(dstAffineForOp->getInstruction());
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
@@ -1472,7 +1484,7 @@ public:
// function.
if (mdg->canRemoveNode(srcNode->id)) {
mdg->removeNode(srcNode->id);
- cast<ForInst>(srcNode->inst)->erase();
+ srcNode->inst->erase();
}
}
}