Diffstat (limited to 'mlir/lib/Analysis/Utils.cpp')
-rw-r--r--  mlir/lib/Analysis/Utils.cpp  |  108
1 file changed, 71 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 2e753f8d10a..bdc5d19d0be 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -122,7 +122,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
-bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
+bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
+ ComputationSliceState *sliceState) {
assert((inst->isa<LoadOp>() || inst->isa<StoreOp>()) &&
"load/store op expected");
@@ -147,18 +148,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
access.getAccessMap(&accessValueMap);
AffineMap accessMap = accessValueMap.getAffineMap();
+ unsigned numDims = accessMap.getNumDims();
+ unsigned numSymbols = accessMap.getNumSymbols();
+ unsigned numOperands = accessValueMap.getNumOperands();
+ // Merge operands with slice operands.
+ SmallVector<Value *, 4> operands;
+ operands.resize(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i)
+ operands[i] = accessValueMap.getOperand(i);
+
+ if (sliceState != nullptr) {
+ // Append slice operands to 'operands' as symbols.
+ operands.append(sliceState->lbOperands[0].begin(),
+ sliceState->lbOperands[0].end());
+ // Update 'numSymbols' to account for the operands appended from 'sliceState'.
+ numSymbols += sliceState->lbOperands[0].size();
+ }
+
// We'll first associate the dims and symbols of the access map to the dims
// and symbols, respectively, of cst. This will change below once cst is
// fully constructed.
- cst.reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0,
- accessValueMap.getOperands());
+ cst.reset(numDims, numSymbols, 0, operands);
// Add equality constraints.
- unsigned numDims = accessMap.getNumDims();
- unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalities for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
- if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
+ auto *operand = operands[i];
+ if (auto loop = getForInductionVarOwner(operand)) {
// Note that cst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
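The element-by-element copy above is equivalent to constructing from the operand range and appending in bulk; a minimal sketch, using only calls visible elsewhere in this diff:

  // Sketch of the same merge with bulk insertion (behavior-identical).
  SmallVector<Value *, 4> operands(accessValueMap.getOperands().begin(),
                                   accessValueMap.getOperands().end());
  if (sliceState) {
    // The slice's lb/ub operands land after the access-map operands, so they
    // occupy the trailing symbol columns of 'cst' once it is reset below.
    operands.append(sliceState->lbOperands[0].begin(),
                    sliceState->lbOperands[0].end());
    numSymbols += sliceState->lbOperands[0].size();
  }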
@@ -167,7 +183,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
return false;
} else {
// Has to be a valid symbol.
- auto *symbol = accessValueMap.getOperand(i);
+ auto *symbol = operand;
assert(isValidSymbol(symbol));
// Check if the symbol is a constant.
if (auto *inst = symbol->getDefiningInst()) {
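The constant-pinning in the symbol branch is what keeps regions tight when an access index is a compile-time constant. A hedged illustration, with assumed IR:

  // Assumed IR, for illustration only:
  //   %c128 = constant 128 : index
  //   %v = load %A[%i, %c128] : memref<256x256xf32>
  // '%c128' is a valid symbol; the branch above registers it with 'cst' and
  // pins it with an equality (symbol == 128), so the second data dimension
  // contributes a single point to the region instead of an unknown extent.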
@@ -178,6 +194,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
}
}
+ // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
+ if (sliceState != nullptr) {
+ // Add dim and symbol slice operands.
+ for (const auto &operand : sliceState->lbOperands[0]) {
+ unsigned loc;
+ if (!cst.findId(*operand, &loc)) {
+ if (isValidSymbol(operand)) {
+ cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand));
+ loc = cst.getNumDimIds() + cst.getNumSymbolIds() - 1;
+ // Check if the symbol is a constant.
+ if (auto *opInst = operand->getDefiningInst()) {
+ if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
+ cst.setIdToConstant(*operand, constOp->getValue());
+ }
+ }
+ } else {
+ cst.addDimId(cst.getNumDimIds(), const_cast<Value *>(operand));
+ loc = cst.getNumDimIds() - 1;
+ }
+ }
+ }
+ // Add upper/lower bounds from 'sliceState' to 'cst'.
+ if (!cst.addSliceBounds(sliceState->lbs, sliceState->ubs,
+ sliceState->lbOperands[0]))
+ return false;
+ }
+
// Add access function equalities to connect loop IVs to data dimensions.
if (!cst.composeMap(&accessValueMap)) {
LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n");
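To make the slice-bound rows concrete, consider a hypothetical single-loop slice of the kind produced for a pointwise producer-consumer pair (names assumed):

  // Source IV %i sliced against destination IV %j:
  //   lbs[0] = (d0) -> (d0)      with operand %j   =>   i >= j
  //   ubs[0] = (d0) -> (d0 + 1)  with operand %j   =>   i <  j + 1
  // With these rows added to 'cst', the region computed for the source
  // access covers only the data touched by the single iteration i == j.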
@@ -233,6 +276,26 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
return llvm::divideCeil(sizeInBits, 8);
}
+// Returns the size of the region in bytes.
+Optional<int64_t> MemRefRegion::getRegionSize() {
+ auto memRefType = memref->getType().cast<MemRefType>();
+
+ auto layoutMaps = memRefType.getAffineMaps();
+ if (layoutMaps.size() > 1 ||
+ (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+ LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ return None;
+ }
+
+ // Compute the number of elements in the region's constant bounding box.
+ Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
+ if (!numElements.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+ return None;
+ }
+ return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
+}
+
/// Returns the size of the memref data in bytes if it's statically shaped,
/// None otherwise. If the memref's elements have vector type, the size of the
/// vector is taken into account as well.
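A sketch of how the new member function might be consumed downstream; `region` is an already-computed `MemRefRegion`, and `kLocalBufLimitBytes` is an assumed budget rather than an MLIR constant:

  // Illustrative only: gate fast-buffer creation on the region's footprint.
  constexpr int64_t kLocalBufLimitBytes = 32 * 1024; // assumed budget
  Optional<int64_t> size = region.getRegionSize();
  if (size.hasValue() && size.getValue() <= kLocalBufLimitBytes) {
    // Small enough to stage into a fast local buffer.
  }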
@@ -420,8 +483,6 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess,
// entire destination index set. Subtract out the dependent destination
// iterations from the destination index set and check for emptiness --- this is one
// solution.
-// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project
-// out loop IVs we don't care about and produce smaller slice.
OpPointer<AffineForOp> mlir::insertBackwardComputationSlice(
Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth,
ComputationSliceState *sliceState) {
@@ -537,33 +598,6 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A,
return numCommonLoops;
}
-// Returns the size of the region.
-static Optional<int64_t> getRegionSize(const MemRefRegion &region) {
- auto *memref = region.memref;
- auto memRefType = memref->getType().cast<MemRefType>();
-
- auto layoutMaps = memRefType.getAffineMaps();
- if (layoutMaps.size() > 1 ||
- (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
- return false;
- }
-
- // Indices to use for the DmaStart op.
- // Indices for the original memref being DMAed from/to.
- SmallVector<Value *, 4> memIndices;
- // Indices for the faster buffer being DMAed into/from.
- SmallVector<Value *, 4> bufIndices;
-
- // Compute the extents of the buffer.
- Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape();
- if (!numElements.hasValue()) {
- LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
- return None;
- }
- return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
-}
-
Optional<int64_t>
mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp,
int memorySpace) {
@@ -601,7 +635,7 @@ Optional<int64_t> mlir::getMemoryFootprintBytes(const Block &block,
int64_t totalSizeInBytes = 0;
for (const auto &region : regions) {
- auto size = getRegionSize(*region);
+ auto size = region->getRegionSize();
if (!size.hasValue())
return None;
totalSizeInBytes += size.getValue();
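Finally, a hedged usage sketch for the footprint query this call site serves; `forOp` and `kCacheSizeBytes` are assumed inputs:

  // Illustrative only: compare a loop nest's footprint against a cache budget.
  Optional<int64_t> footprint =
      getMemoryFootprintBytes(forOp, /*memorySpace=*/0);
  if (footprint.hasValue() && footprint.getValue() > kCacheSizeBytes) {
    // The nest overflows the budget; consider tiling or staging via DMAs.
  }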