diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Analysis/AffineStructures.h | 10 | ||||
| -rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 28 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 7 |
3 files changed, 26 insertions, 19 deletions
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 81006e661f9..0ba77cccf90 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -436,11 +436,15 @@ public: /// constant. Asserts if the 'id' is not found. void setIdToConstant(const Value &id, int64_t val); - /// Looks up the identifier with the specified Value. Returns false if not - /// found, true if found. pos is set to the (column) position of the - /// identifier. + /// Looks up the position of the identifier with the specified Value. Returns + /// true if found (false otherwise). `pos' is set to the (column) position of + /// the identifier. bool findId(const Value &id, unsigned *pos) const; + /// Returns true if an identifier with the specified Value exists, false + /// otherwise. + bool containsId(const Value &id) const; + // Add identifiers of the specified kind - specified positions are relative to // the kind of identifier. The coefficient column corresponding to the added // identifier is initialized to zero. 'id' is the Value corresponding to the diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index c6287e281c4..543ae1400d4 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1888,6 +1888,12 @@ bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const { return false; } +bool FlatAffineConstraints::containsId(const Value &id) const { + return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) { + return mayBeId.hasValue() && mayBeId.getValue() == &id; + }); +} + void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { assert(newSymbolCount <= numDims + numSymbols && "invalid separation position"); @@ -2696,19 +2702,21 @@ bool FlatAffineConstraints::unionBoundingBox( boundingLbs.reserve(2 * getNumDimIds()); boundingUbs.reserve(2 * getNumDimIds()); - SmallVector<int64_t, 4> lb, otherLb; - lb.reserve(getNumSymbolIds() + 1); - otherLb.reserve(getNumSymbolIds() + 1); + // To hold lower and upper bounds for each dimension. + SmallVector<int64_t, 4> lb, otherLb, ub, otherUb; + // To compute min of lower bounds and max of upper bounds for each dimension. + SmallVector<int64_t, 4> minLb, maxUb; + // To compute final new lower and upper bounds for the union. + SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols()); + int64_t lbDivisor, otherLbDivisor; for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { - lb.clear(); auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor); if (!extent.hasValue()) // TODO(bondhugula): symbolic extents when necessary. // TODO(bondhugula): handle union if a dimension is unbounded. return false; - otherLb.clear(); auto otherExtent = other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor); if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor) @@ -2717,9 +2725,6 @@ bool FlatAffineConstraints::unionBoundingBox( assert(lbDivisor > 0 && "divisor always expected to be positive"); - // Compute min of lower bounds and max of upper bounds. - SmallVector<int64_t, 4> minLb, maxUb; - auto res = compareBounds(lb, otherLb); // Identify min. if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { @@ -2737,7 +2742,8 @@ bool FlatAffineConstraints::unionBoundingBox( } // Do the same for ub's but max of upper bounds. - SmallVector<int64_t, 4> ub(lb), otherUb(otherLb); + ub = lb; + otherUb = otherLb; ub.back() += extent.getValue() - 1; otherUb.back() += otherExtent.getValue() - 1; @@ -2757,8 +2763,8 @@ bool FlatAffineConstraints::unionBoundingBox( maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); } - SmallVector<int64_t, 8> newLb(getNumCols(), 0); - SmallVector<int64_t, 8> newUb(getNumCols(), 0); + std::fill(newLb.begin(), newLb.end(), 0); + std::fill(newUb.begin(), newUb.end(), 0); // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, // and so it's the divisor for newLb and newUb as well. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index bf2f82e29b1..ba6f79d6d77 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -70,9 +70,7 @@ bool ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { // Add loop bound constraints for values which are loop IVs and equality // constraints for symbols which are constants. for (const auto &value : values) { - unsigned loc; - (void)loc; - assert(cst->findId(*value, &loc)); + assert(cst->containsId(*value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. if (auto *inst = value->getDefiningInst()) { @@ -256,8 +254,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, if (sliceState != nullptr) { // Add dim and symbol slice operands. for (const auto &operand : sliceState->lbOperands[0]) { - unsigned loc; - if (!cst.findId(*operand, &loc)) { + if (!cst.containsId(*operand)) { if (isValidSymbol(operand)) { cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand)); // Check if the symbol is a constant. |

