diff options
Diffstat (limited to 'mlir/lib/Analysis/Utils.cpp')
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 114 |
1 files changed, 77 insertions, 37 deletions
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 2e753f8d10a..bdc5d19d0be 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -122,7 +122,8 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { +bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, + ComputationSliceState *sliceState) { assert((inst->isa<LoadOp>() || inst->isa<StoreOp>()) && "load/store op expected"); @@ -147,18 +148,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { access.getAccessMap(&accessValueMap); AffineMap accessMap = accessValueMap.getAffineMap(); + unsigned numDims = accessMap.getNumDims(); + unsigned numSymbols = accessMap.getNumSymbols(); + unsigned numOperands = accessValueMap.getNumOperands(); + // Merge operands with slice operands. + SmallVector<Value *, 4> operands; + operands.resize(numOperands); + for (unsigned i = 0; i < numOperands; ++i) + operands[i] = accessValueMap.getOperand(i); + + if (sliceState != nullptr) { + // Append slice operands to 'operands' as symbols. + operands.append(sliceState->lbOperands[0].begin(), + sliceState->lbOperands[0].end()); + // Update 'numSymbols' by operands from 'sliceState'. + numSymbols += sliceState->lbOperands[0].size(); + } + // We'll first associate the dims and symbols of the access map to the dims // and symbols resp. of cst. This will change below once cst is // fully constructed out. - cst.reset(accessMap.getNumDims(), accessMap.getNumSymbols(), 0, - accessValueMap.getOperands()); + cst.reset(numDims, numSymbols, 0, operands); // Add equality constraints. - unsigned numDims = accessMap.getNumDims(); - unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. 
for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { + auto *operand = operands[i]; + if (auto loop = getForInductionVarOwner(operand)) { // Note that cst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way @@ -167,7 +183,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { return false; } else { // Has to be a valid symbol. - auto *symbol = accessValueMap.getOperand(i); + auto *symbol = operand; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *inst = symbol->getDefiningInst()) { @@ -178,6 +194,33 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) { } } + // Add lower/upper bounds on loop IVs using bounds from 'sliceState'. + if (sliceState != nullptr) { + // Add dim and symbol slice operands. + for (const auto &operand : sliceState->lbOperands[0]) { + unsigned loc; + if (!cst.findId(*operand, &loc)) { + if (isValidSymbol(operand)) { + cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand)); + loc = cst.getNumDimIds() + cst.getNumSymbolIds() - 1; + // Check if the symbol is a constant. + if (auto *opInst = operand->getDefiningInst()) { + if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) { + cst.setIdToConstant(*operand, constOp->getValue()); + } + } + } else { + cst.addDimId(cst.getNumDimIds(), const_cast<Value *>(operand)); + loc = cst.getNumDimIds() - 1; + } + } + } + // Add upper/lower bounds from 'sliceState' to 'cst'. + if (!cst.addSliceBounds(sliceState->lbs, sliceState->ubs, + sliceState->lbOperands[0])) + return false; + } + // Add access function equalities to connect loop IVs to data dimensions. 
if (!cst.composeMap(&accessValueMap)) { LLVM_DEBUG(llvm::dbgs() << "getMemRefRegion: compose affine map failed\n"); @@ -233,6 +276,32 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { return llvm::divideCeil(sizeInBits, 8); } +// Returns the size of the region. +Optional<int64_t> MemRefRegion::getRegionSize() { + auto memRefType = memref->getType().cast<MemRefType>(); + + auto layoutMaps = memRefType.getAffineMaps(); + if (layoutMaps.size() > 1 || + (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { + LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + return false; + } + + // Indices to use for the DmaStart op. + // Indices for the original memref being DMAed from/to. + SmallVector<Value *, 4> memIndices; + // Indices for the faster buffer being DMAed into/from. + SmallVector<Value *, 4> bufIndices; + + // Compute the extents of the buffer. + Optional<int64_t> numElements = getConstantBoundingSizeAndShape(); + if (!numElements.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); + return None; + } + return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); +} + /// Returns the size of memref data in bytes if it's statically shaped, None /// otherwise. If the element of the memref has vector type, takes into account /// size of the vector as well. @@ -420,8 +489,6 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // entire destination index set. Subtract out the dependent destination // iterations from destination index set and check for emptiness --- this is one // solution. -// TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project -// out loop IVs we don't care about and produce smaller slice. 
OpPointer<AffineForOp> mlir::insertBackwardComputationSlice( Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { @@ -537,33 +604,6 @@ unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A, return numCommonLoops; } -// Returns the size of the region. -static Optional<int64_t> getRegionSize(const MemRefRegion &region) { - auto *memref = region.memref; - auto memRefType = memref->getType().cast<MemRefType>(); - - auto layoutMaps = memRefType.getAffineMaps(); - if (layoutMaps.size() > 1 || - (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) { - LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); - return false; - } - - // Indices to use for the DmaStart op. - // Indices for the original memref being DMAed from/to. - SmallVector<Value *, 4> memIndices; - // Indices for the faster buffer being DMAed into/from. - SmallVector<Value *, 4> bufIndices; - - // Compute the extents of the buffer. - Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(); - if (!numElements.hasValue()) { - LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); - return None; - } - return getMemRefEltSizeInBytes(memRefType) * numElements.getValue(); -} - Optional<int64_t> mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp, int memorySpace) { @@ -601,7 +641,7 @@ Optional<int64_t> mlir::getMemoryFootprintBytes(const Block &block, int64_t totalSizeInBytes = 0; for (const auto &region : regions) { - auto size = getRegionSize(*region); + auto size = region->getRegionSize(); if (!size.hasValue()) return None; totalSizeInBytes += size.getValue(); |

