summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2019-02-04 07:58:42 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:09:52 -0700
commitb26900dce55c93043e8f84580df4a1bec65408be (patch)
treee752280cbe8f34910fb950f4104f35357f844a5f /mlir/lib/Transforms
parent870d7783503962a7043b2654ab82a9d4f4f1a961 (diff)
downloadbcm5719-llvm-b26900dce55c93043e8f84580df4a1bec65408be.tar.gz
bcm5719-llvm-b26900dce55c93043e8f84580df4a1bec65408be.zip
Update dma-generate pass to (1) work on blocks of instructions (instead of just
loops), (2) take into account fast memory space capacity and lower 'dmaDepth' to fit, (3) add location information for debug info / errors - change dma-generate pass to work on blocks of instructions (start/end iterators) instead of 'for' loops; complete TODOs - allows DMA generation for straightline blocks of operation instructions interspersed b/w loops - take into account fast memory capacity: check whether memory footprint fits in fastMemoryCapacity parameter, and recurse/lower the depth at which DMA generation is performed until it does fit in the provided memory - add location information to MemRefRegion; any insufficient fast memory capacity errors or debug info w.r.t dma generation shows location information - allow DMA generation pass to be instantiated with a fast memory capacity option (besides command line flag) - change getMemRefRegion to return unique_ptr's - change getMemRefFootprintBytes to work on a 'Block' instead of 'ForInst' - other helper methods; add postDomInstFilter option for replaceAllMemRefUsesWith; drop forInst->walkOps, add Block::walkOps methods Eg. output $ mlir-opt -dma-generate -dma-fast-mem-capacity=1 /tmp/single.mlir /tmp/single.mlir:9:13: error: Total size of all DMA buffers' for this block exceeds fast memory capacity for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { ^ $ mlir-opt -debug-only=dma-generate -dma-generate -dma-fast-mem-capacity=400 /tmp/single.mlir /tmp/single.mlir:9:13: note: 8 KiB of DMA buffers in fast memory space for this block for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0 + 32)(%i1) { PiperOrigin-RevId: 232297044
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp343
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp7
-rw-r--r--mlir/lib/Transforms/MemRefDataFlowOpt.cpp5
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp15
4 files changed, 262 insertions, 108 deletions
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 83ec726ec2a..2bbb32036c2 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -47,7 +47,7 @@ static llvm::cl::opt<unsigned> clFastMemorySpace(
llvm::cl::desc("Set fast memory space id for DMA generation"),
llvm::cl::cat(clOptionsCategory));
-static llvm::cl::opt<uint64_t> clFastMemoryCapacity(
+static llvm::cl::opt<unsigned> clFastMemoryCapacity(
"dma-fast-mem-capacity", llvm::cl::Hidden,
llvm::cl::desc("Set fast memory space capacity in KiB"),
llvm::cl::cat(clOptionsCategory));
@@ -57,25 +57,28 @@ namespace {
/// Generates DMAs for memref's living in 'slowMemorySpace' into newly created
/// buffers in 'fastMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now.
-/// TODO(bondhugula): extend this to store op's.
+// TODO(bondhugula): We currently can't generate DMAs correctly when stores are
+// strided. Check for strided stores.
+// TODO(mlir-team): we don't insert dealloc's for the DMA buffers; this is thus
+// natural only for scoped allocations.
struct DmaGeneration : public FunctionPass {
- explicit DmaGeneration(unsigned slowMemorySpace = 0,
- unsigned fastMemorySpaceArg = 1,
- int minDmaTransferSize = 1024)
+ explicit DmaGeneration(
+ unsigned slowMemorySpace = 0, unsigned fastMemorySpace = 1,
+ int minDmaTransferSize = 1024,
+ uint64_t fastMemCapacityBytes = std::numeric_limits<uint64_t>::max())
: FunctionPass(&DmaGeneration::passID), slowMemorySpace(slowMemorySpace),
- minDmaTransferSize(minDmaTransferSize) {
- if (clFastMemorySpace.getNumOccurrences() > 0) {
- fastMemorySpace = clFastMemorySpace;
- } else {
- fastMemorySpace = fastMemorySpaceArg;
- }
- }
+ fastMemorySpace(fastMemorySpace),
+ minDmaTransferSize(minDmaTransferSize),
+ fastMemCapacityBytes(fastMemCapacityBytes) {}
PassResult runOnFunction(Function *f) override;
- void runOnAffineForOp(OpPointer<AffineForOp> forOp);
+ bool runOnBlock(Block *block, uint64_t consumedCapacityBytes);
+ uint64_t runOnBlock(Block::iterator begin, Block::iterator end);
- bool generateDma(const MemRefRegion &region, OpPointer<AffineForOp> forOp,
- uint64_t *sizeInBytes);
+ bool generateDma(const MemRefRegion &region, Block *block,
+ Block::iterator begin, Block::iterator end,
+ uint64_t *sizeInBytes, Block::iterator *nBegin,
+ Block::iterator *nEnd);
// List of memory regions to DMA for. We need a map vector to have a
// guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
@@ -93,6 +96,8 @@ struct DmaGeneration : public FunctionPass {
unsigned fastMemorySpace;
// Minimum DMA transfer size supported by the target in bytes.
const int minDmaTransferSize;
+ // Capacity of the faster memory space.
+ uint64_t fastMemCapacityBytes;
// Constant zero index to avoid too many duplicates.
Value *zeroIndex = nullptr;
@@ -110,9 +115,10 @@ char DmaGeneration::passID = 0;
/// TODO(bondhugula): extend this to store op's.
FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
unsigned fastMemorySpace,
- int minDmaTransferSize) {
- return new DmaGeneration(slowMemorySpace, fastMemorySpace,
- minDmaTransferSize);
+ int minDmaTransferSize,
+ uint64_t fastMemCapacityBytes) {
+ return new DmaGeneration(slowMemorySpace, fastMemorySpace, minDmaTransferSize,
+ fastMemCapacityBytes);
}
// Info comprising stride and number of elements transferred every stride.
@@ -192,26 +198,48 @@ static bool getFullMemRefAsRegion(OperationInst *opInst,
return true;
}
-// Creates a buffer in the faster memory space for the specified region;
-// generates a DMA from the lower memory space to this one, and replaces all
-// loads to load from that buffer. Returns false if DMAs could not be generated
-// due to yet unimplemented cases.
-bool DmaGeneration::generateDma(const MemRefRegion &region,
- OpPointer<AffineForOp> forOp,
- uint64_t *sizeInBytes) {
- auto *forInst = forOp->getInstruction();
+static void emitNoteForBlock(const Block &block, const Twine &message) {
+ auto *inst = block.getContainingInst();
+ if (!inst) {
+ block.getFunction()->emitNote(message);
+ } else {
+ inst->emitNote(message);
+ }
+}
+
+/// Creates a buffer in the faster memory space for the specified region;
+/// generates a DMA from the lower memory space to this one, and replaces all
+/// loads to load from that buffer. Returns false if DMAs could not be generated
+/// due to yet unimplemented cases. `begin` and `end` specify the insertion
+/// points where the incoming DMAs and outgoing DMAs, respectively, should
+/// be inserted (the insertion happens right before the insertion point). Since
+/// `begin` can itself be invalidated due to the memref rewriting done from this
+/// method, the output argument `nBegin` is set to its replacement (set
+/// to `begin` if no invalidation happens). Since outgoing DMAs are inserted at
+/// `end`, the output argument `nEnd` is set to the one following the original
+/// end (since the latter could have been invalidated/replaced). `sizeInBytes`
+/// is set to the size of the DMA buffer allocated.
+bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
+ Block::iterator begin, Block::iterator end,
+ uint64_t *sizeInBytes, Block::iterator *nBegin,
+ Block::iterator *nEnd) {
+ *nBegin = begin;
+ *nEnd = end;
+
+ if (begin == end)
+ return true;
// DMAs for read regions are going to be inserted just before the for loop.
- FuncBuilder prologue(forInst);
+ FuncBuilder prologue(block, begin);
// DMAs for write regions are going to be inserted just after the for loop.
- FuncBuilder epilogue(forInst->getBlock(),
- std::next(Block::iterator(forInst)));
+ FuncBuilder epilogue(block, end);
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
// Builder to create constants at the top level.
- FuncBuilder top(forInst->getFunction());
+ auto *func = block->getFunction();
+ FuncBuilder top(func);
- auto loc = forInst->getLoc();
+ auto loc = region.loc;
auto *memref = region.memref;
auto memRefType = memref->getType().cast<MemRefType>();
@@ -310,21 +338,17 @@ bool DmaGeneration::generateDma(const MemRefRegion &region,
auto fastMemRefType = top.getMemRefType(
fastBufferShape, memRefType.getElementType(), {}, fastMemorySpace);
- LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
- LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");
-
// Create the fast memory space buffer just before the 'for' instruction.
fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
// fastMemRefType is a constant shaped memref.
*sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
- LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type ";
+ LLVM_DEBUG(emitNoteForBlock(*block, "Creating DMA buffer of type ");
fastMemRefType.dump();
llvm::dbgs()
- << " and size " << Twine(llvm::divideCeil(*sizeInBytes, 1024))
+ << " and of size " << Twine(llvm::divideCeil(*sizeInBytes, 1024))
<< " KiB\n";);
-
} else {
// Reuse the one already created.
fastMemRef = fastBufferMap[memref];
@@ -336,9 +360,6 @@ bool DmaGeneration::generateDma(const MemRefRegion &region,
auto numElementsSSA =
top.create<ConstantIndexOp>(loc, numElements.getValue());
- // TODO(bondhugula): check for transfer sizes not being a multiple of
- // minDmaTransferSize and handle them appropriately.
-
SmallVector<StrideInfo, 4> strideInfos;
getMultiLevelStrides(region, fastBufferShape, &strideInfos);
@@ -357,6 +378,12 @@ bool DmaGeneration::generateDma(const MemRefRegion &region,
top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
}
+ // Record the last instruction just before the point where we insert the
+ // outgoing DMAs. We later do the memref replacement later only in [begin,
+ // postDomFilter] so that the original memref's in the DMA ops themselves
+ // don't get replaced.
+ auto postDomFilter = std::prev(end);
+
if (!region.isWrite()) {
// DMA non-blocking read from original buffer to fast buffer.
b->create<DmaStartOp>(loc, memref, memIndices, fastMemRef, bufIndices,
@@ -364,9 +391,13 @@ bool DmaGeneration::generateDma(const MemRefRegion &region,
numEltPerStride);
} else {
// DMA non-blocking write from fast buffer to the original memref.
- b->create<DmaStartOp>(loc, fastMemRef, bufIndices, memref, memIndices,
- numElementsSSA, tagMemRef, zeroIndex, stride,
- numEltPerStride);
+ auto op = b->create<DmaStartOp>(loc, fastMemRef, bufIndices, memref,
+ memIndices, numElementsSSA, tagMemRef,
+ zeroIndex, stride, numEltPerStride);
+ // Since new ops are being appended (for outgoing DMAs), adjust the end to
+ // mark end of range of the original.
+ if (*nEnd == end)
+ *nEnd = Block::iterator(op->getInstruction());
}
// Matching DMA wait to block on completion; tag always has a 0 index.
@@ -389,45 +420,151 @@ bool DmaGeneration::generateDma(const MemRefRegion &region,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
- // *Only* those uses within the body of 'forOp' are replaced.
+
+ // Record the begin since it may be invalidated by memref replacement.
+ Block::iterator prev;
+ bool wasAtStartOfBlock = (begin == block->begin());
+ if (!wasAtStartOfBlock)
+ prev = std::prev(begin);
+
+ // *Only* those uses within the range [begin, end) of 'block' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domInstFilter=*/&*forOp->getBody()->begin());
+ /*domInstFilter=*/&*begin,
+ /*postDomInstFilter=*/&*postDomFilter);
+
+ *nBegin = wasAtStartOfBlock ? block->begin() : std::next(prev);
+
return true;
}
-// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
-void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
- // For now (for testing purposes), we'll run this on the outermost among 'for'
- // inst's with unit stride, i.e., right at the top of the tile if tiling has
- // been done. In the future, the DMA generation has to be done at a level
- // where the generated data fits in a higher level of the memory hierarchy; so
- // the pass has to be instantiated with additional information that we aren't
- // provided with at the moment.
- if (forOp->getStep() != 1) {
- auto *forBody = forOp->getBody();
- if (forBody->empty())
- return;
- if (auto innerFor =
- cast<OperationInst>(forBody->front()).dyn_cast<AffineForOp>()) {
- runOnAffineForOp(innerFor);
+/// Generate DMAs for this block. The block is partitioned into separate
+/// `regions`; each region is either a sequence of one or more instructions
+/// starting and ending with a load or store op, or just a loop (which could
+/// have other loops nested within). Returns false on an error, true otherwise.
+bool DmaGeneration::runOnBlock(Block *block, uint64_t consumedCapacityBytes) {
+ block->dump();
+ if (block->empty())
+ return true;
+
+ uint64_t priorConsumedCapacityBytes = consumedCapacityBytes;
+
+ // Every loop in the block starts and ends a region. A contiguous sequence of
+ // operation instructions starting and ending with a load/store op is also
+ // identified as a region. Straightline code (contiguous chunks of operation
+ // instructions) are always assumed to not exhaust memory. As a result, this
+ // approach is conservative in some cases at the moment, we do a check later
+ // and report an error with location info.
+ // TODO(bondhugula): An 'if' instruction is being treated similar to an
+ // operation instruction. 'if''s could have 'for's in them; treat them
+ // separately.
+
+ // Get to the first load, store, or for op.
+ auto curBegin =
+ std::find_if(block->begin(), block->end(), [&](const Instruction &inst) {
+ return inst.isa<LoadOp>() || inst.isa<StoreOp>() ||
+ inst.isa<AffineForOp>();
+ });
+
+ for (auto it = curBegin; it != block->end(); ++it) {
+ if (auto forOp = it->dyn_cast<AffineForOp>()) {
+ // We'll assume for now that loops with steps are tiled loops, and so DMAs
+ // are not performed for that depth, but only further inside.
+ // If the memory footprint of the 'for' loop is higher than fast memory
+ // capacity (when provided), we recurse to DMA at an inner level until
+ // we find a depth at which footprint fits in the capacity. If the
+ // footprint can't be calcuated, we assume for now it fits.
+
+ // Returns true if the footprint is known to exceed capacity.
+ auto exceedsCapacity = [&](OpPointer<AffineForOp> forOp) {
+ Optional<int64_t> footprint;
+ return ((footprint = getMemoryFootprintBytes(forOp, 0)).hasValue() &&
+ consumedCapacityBytes +
+ static_cast<uint64_t>(footprint.getValue()) >
+ fastMemCapacityBytes);
+ };
+
+ if (forOp->getStep() != 1 || exceedsCapacity(forOp)) {
+ // We'll split and do the DMAs one or more levels inside for forInst
+ consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it);
+ // Recurse onto the body of this loop.
+ runOnBlock(forOp->getBody(), consumedCapacityBytes);
+ // The next region starts right after the 'for' instruction.
+ curBegin = std::next(it);
+ } else {
+ // We have enough capacity, i.e., DMAs will be computed for the portion
+ // of the block until 'it', and for the 'for' loop. For the latter, they
+ // are placed just before this loop (for incoming DMAs) and right after
+ // (for outgoing ones).
+ consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it);
+
+ // Inner loop DMAs have their own scope - we don't thus update consumed
+ // capacity. The footprint check above guarantees this inner loop's
+ // footprint fits.
+ runOnBlock(/*begin=*/it, /*end=*/std::next(it));
+ curBegin = std::next(it);
+ }
+ } else if (!it->isa<LoadOp>() && !it->isa<StoreOp>()) {
+ consumedCapacityBytes += runOnBlock(/*begin=*/curBegin, /*end=*/it);
+ curBegin = std::next(it);
}
- return;
}
- // DMAs will be generated for this depth, i.e., for all data accessed by this
- // loop.
- unsigned dmaDepth = getNestingDepth(*forOp->getInstruction());
+ // Generate the DMA for the final region.
+ if (curBegin != block->end()) {
+ // Can't be a terminator because it would have been skipped above.
+ assert(!curBegin->isTerminator() && "can't be a terminator");
+ consumedCapacityBytes +=
+ runOnBlock(/*begin=*/curBegin, /*end=*/block->end());
+ }
+
+ if (llvm::DebugFlag) {
+ uint64_t thisBlockDmaSizeBytes =
+ consumedCapacityBytes - priorConsumedCapacityBytes;
+ if (thisBlockDmaSizeBytes > 0) {
+ emitNoteForBlock(
+ *block,
+ Twine(llvm::divideCeil(thisBlockDmaSizeBytes, 1024)) +
+ " KiB of DMA buffers in fast memory space for this block\n");
+ }
+ }
+
+ if (consumedCapacityBytes > fastMemCapacityBytes) {
+ StringRef str = "Total size of all DMA buffers' for this block "
+ "exceeds fast memory capacity\n";
+ if (auto *inst = block->getContainingInst())
+ inst->emitError(str);
+ else
+ block->getFunction()->emitError(str);
+ return false;
+ }
+
+ return true;
+}
+
+/// Generates DMAs for a contiguous sequence of instructions in `block` in the
+/// iterator range [begin, end). Returns the total size of the DMA buffers used.
+uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
+ if (begin == end)
+ return 0;
+
+ assert(begin->getBlock() == std::prev(end)->getBlock() &&
+ "Inconsistent args");
+
+ Block *block = begin->getBlock();
+
+ // DMAs will be generated for this depth, i.e., symbolic in all loops
+ // surrounding the region of this block.
+ unsigned dmaDepth = getNestingDepth(*begin);
readRegions.clear();
writeRegions.clear();
fastBufferMap.clear();
- // Walk this 'for' instruction to gather all memory regions.
- forOp->walkOps([&](OperationInst *opInst) {
- // Gather regions to promote to buffers in faster memory space.
- // TODO(bondhugula): handle store op's; only load's handled for now.
+ // Walk this range of instructions to gather all memory regions.
+ block->walk(begin, end, [&](OperationInst *opInst) {
+ // Gather regions to allocate to buffers in faster memory space.
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
@@ -439,18 +576,15 @@ void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
return;
}
- // TODO(bondhugula): eventually, we need to be performing a union across
- // all regions for a given memref instead of creating one region per
- // memory op. This way we would be allocating O(num of memref's) sets
- // instead of O(num of load/store op's).
- auto region = std::make_unique<MemRefRegion>();
- if (!getMemRefRegion(opInst, dmaDepth, region.get())) {
+ // Compute the MemRefRegion accessed.
+ auto region = getMemRefRegion(opInst, dmaDepth);
+ if (!region) {
LLVM_DEBUG(llvm::dbgs()
<< "Error obtaining memory region: semi-affine maps?\n");
LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(
- forOp->emitError("Non-constant memref sizes not yet supported"));
+ opInst->emitError("Non-constant memref sizes not yet supported"));
return;
}
}
@@ -477,12 +611,12 @@ void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
return false;
// Perform a union with the existing region.
- if (!(*it).second->unionBoundingBox(*region)) {
+ if (!it->second->unionBoundingBox(*region)) {
LLVM_DEBUG(llvm::dbgs()
- << "Memory region bounding box failed"
+ << "Memory region bounding box failed; "
"over-approximating to the entire memref\n");
if (!getFullMemRefAsRegion(opInst, dmaDepth, region.get())) {
- LLVM_DEBUG(forOp->emitError(
+ LLVM_DEBUG(opInst->emitError(
"Non-constant memref sizes not yet supported"));
}
}
@@ -500,48 +634,59 @@ void DmaGeneration::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
}
});
- uint64_t totalSizeInBytes = 0;
-
+ uint64_t totalDmaBuffersSizeInBytes = 0;
bool ret = true;
auto processRegions =
[&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
&regions) {
for (const auto &regionEntry : regions) {
uint64_t sizeInBytes;
- bool iRet = generateDma(*regionEntry.second, forOp, &sizeInBytes);
- if (iRet)
- totalSizeInBytes += sizeInBytes;
+ Block::iterator nBegin, nEnd;
+ bool iRet = generateDma(*regionEntry.second, block, begin, end,
+ &sizeInBytes, &nBegin, &nEnd);
+ if (iRet) {
+ begin = nBegin;
+ end = nEnd;
+ totalDmaBuffersSizeInBytes += sizeInBytes;
+ }
ret = ret & iRet;
}
};
processRegions(readRegions);
processRegions(writeRegions);
+
if (!ret) {
- forOp->emitError("DMA generation failed for one or more memref's\n");
- return;
+ begin->emitError(
+ "DMA generation failed for one or more memref's in this block\n");
+ return totalDmaBuffersSizeInBytes;
}
- LLVM_DEBUG(llvm::dbgs() << Twine(llvm::divideCeil(totalSizeInBytes, 1024))
- << " KiB of DMA buffers in fast memory space\n";);
-
- if (clFastMemoryCapacity && totalSizeInBytes > clFastMemoryCapacity) {
- // TODO(bondhugula): selecting the DMA depth so that the result DMA buffers
- // fit in fast memory is a TODO - not complex.
- forOp->emitError(
- "Total size of all DMA buffers' exceeds memory capacity\n");
+
+ // For a range of operation instructions, a note will be emitted at the
+ // caller.
+ OpPointer<AffineForOp> forOp;
+ if (llvm::DebugFlag && (forOp = begin->dyn_cast<AffineForOp>())) {
+ forOp->emitNote(
+ Twine(llvm::divideCeil(totalDmaBuffersSizeInBytes, 1024)) +
+ " KiB of DMA buffers in fast memory space for this block\n");
}
+
+ return totalDmaBuffersSizeInBytes;
}
PassResult DmaGeneration::runOnFunction(Function *f) {
FuncBuilder topBuilder(f);
-
zeroIndex = topBuilder.create<ConstantIndexOp>(f->getLoc(), 0);
+ if (clFastMemorySpace.getNumOccurrences() > 0) {
+ fastMemorySpace = clFastMemorySpace;
+ }
+
+ if (clFastMemoryCapacity.getNumOccurrences() > 0) {
+ fastMemCapacityBytes = clFastMemoryCapacity * 1024;
+ }
+
for (auto &block : *f) {
- for (auto &inst : block) {
- if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>()) {
- runOnAffineForOp(forOp);
- }
- }
+ runOnBlock(&block, /*consumedCapacityBytes=*/0);
}
// This function never leaves the IR in an invalid state.
return success();
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 5091e3ceb33..162e0e3b7f6 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -929,8 +929,7 @@ static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
unsigned rank = oldMemRefType.getRank();
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
- MemRefRegion region;
- getMemRefRegion(srcStoreOpInst, dstLoopDepth, &region);
+ auto region = getMemRefRegion(srcStoreOpInst, dstLoopDepth);
SmallVector<int64_t, 4> newShape;
std::vector<SmallVector<int64_t, 4>> lbs;
SmallVector<int64_t, 8> lbDivisors;
@@ -938,11 +937,11 @@ static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
// Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
// by 'srcStoreOpInst' at depth 'dstLoopDepth'.
Optional<int64_t> numElements =
- region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
+ region->getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
assert(numElements.hasValue() &&
"non-constant number of elts in local buffer");
- const FlatAffineConstraints *cst = region.getConstraints();
+ const FlatAffineConstraints *cst = region->getConstraints();
// 'outerIVs' holds the values that this memory region is symbolic/paramteric
// on; this would correspond to loop IVs surrounding the level at which the
// slice is being materialized.
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 4191a9cc279..e6ce273b532 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -178,9 +178,8 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
// is trivially loading from a single location at that depth; so there
// isn't a need to call isRangeOneToOne.
if (getNestingDepth(*storeOpInst) < loadOpDepth) {
- MemRefRegion region;
- getMemRefRegion(loadOpInst, nsLoops, &region);
- if (!region.getConstraints()->isRangeOneToOne(
+ auto region = getMemRefRegion(loadOpInst, nsLoops);
+ if (!region->getConstraints()->isRangeOneToOne(
/*start=*/0, /*limit=*/loadOp->getMemRefType().getRank()))
break;
}
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 819f1a59b6f..732062a8b97 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -48,7 +48,8 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands,
- const Instruction *domInstFilter) {
+ const Instruction *domInstFilter,
+ const Instruction *postDomInstFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
@@ -66,9 +67,14 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
newMemRef->getType().cast<MemRefType>().getElementType());
std::unique_ptr<DominanceInfo> domInfo;
+ std::unique_ptr<PostDominanceInfo> postDomInfo;
if (domInstFilter)
domInfo = std::make_unique<DominanceInfo>(domInstFilter->getFunction());
+ if (postDomInstFilter)
+ postDomInfo =
+ std::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction());
+
// The ops where memref replacement succeeds are replaced with new ones.
SmallVector<OperationInst *, 8> opsToErase;
@@ -81,6 +87,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
continue;
+ // Skip this use if it's not post-dominated by postDomInstFilter.
+ if (postDomInstFilter &&
+ !postDomInfo->postDominates(postDomInstFilter, opInst))
+ continue;
+
// Check if the memref was used in a non-deferencing context. It is fine for
// the memref to be used in a non-deferencing way outside of the region
// where this replacement is happening.
@@ -167,7 +178,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
res->replaceAllUsesWith(repOp->getResult(r++));
}
// Collect and erase at the end since one of these op's could be
- // domInstFilter!
+ // domInstFilter or postDomInstFilter as well!
opsToErase.push_back(opInst);
}
OpenPOWER on IntegriCloud