summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/AffineStructures.h10
-rw-r--r--mlir/lib/Analysis/AffineStructures.cpp28
-rw-r--r--mlir/lib/Analysis/Utils.cpp7
3 files changed, 26 insertions, 19 deletions
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 81006e661f9..0ba77cccf90 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -436,11 +436,15 @@ public:
/// constant. Asserts if the 'id' is not found.
void setIdToConstant(const Value &id, int64_t val);
- /// Looks up the identifier with the specified Value. Returns false if not
- /// found, true if found. pos is set to the (column) position of the
- /// identifier.
+ /// Looks up the position of the identifier with the specified Value. Returns
+ /// true if found (false otherwise). `pos' is set to the (column) position of
+ /// the identifier.
bool findId(const Value &id, unsigned *pos) const;
+ /// Returns true if an identifier with the specified Value exists, false
+ /// otherwise.
+ bool containsId(const Value &id) const;
+
// Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. The coefficient column corresponding to the added
// identifier is initialized to zero. 'id' is the Value corresponding to the
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index c6287e281c4..543ae1400d4 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1888,6 +1888,12 @@ bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
return false;
}
+bool FlatAffineConstraints::containsId(const Value &id) const {
+ return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) {
+ return mayBeId.hasValue() && mayBeId.getValue() == &id;
+ });
+}
+
void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
assert(newSymbolCount <= numDims + numSymbols &&
"invalid separation position");
@@ -2696,19 +2702,21 @@ bool FlatAffineConstraints::unionBoundingBox(
boundingLbs.reserve(2 * getNumDimIds());
boundingUbs.reserve(2 * getNumDimIds());
- SmallVector<int64_t, 4> lb, otherLb;
- lb.reserve(getNumSymbolIds() + 1);
- otherLb.reserve(getNumSymbolIds() + 1);
+ // To hold lower and upper bounds for each dimension.
+ SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
+ // To compute min of lower bounds and max of upper bounds for each dimension.
+ SmallVector<int64_t, 4> minLb, maxUb;
+ // To compute final new lower and upper bounds for the union.
+ SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
+
int64_t lbDivisor, otherLbDivisor;
for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
- lb.clear();
auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor);
if (!extent.hasValue())
// TODO(bondhugula): symbolic extents when necessary.
// TODO(bondhugula): handle union if a dimension is unbounded.
return false;
- otherLb.clear();
auto otherExtent =
other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor);
if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor)
@@ -2717,9 +2725,6 @@ bool FlatAffineConstraints::unionBoundingBox(
assert(lbDivisor > 0 && "divisor always expected to be positive");
- // Compute min of lower bounds and max of upper bounds.
- SmallVector<int64_t, 4> minLb, maxUb;
-
auto res = compareBounds(lb, otherLb);
// Identify min.
if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
@@ -2737,7 +2742,8 @@ bool FlatAffineConstraints::unionBoundingBox(
}
// Do the same for ub's but max of upper bounds.
- SmallVector<int64_t, 4> ub(lb), otherUb(otherLb);
+ ub = lb;
+ otherUb = otherLb;
ub.back() += extent.getValue() - 1;
otherUb.back() += otherExtent.getValue() - 1;
@@ -2757,8 +2763,8 @@ bool FlatAffineConstraints::unionBoundingBox(
maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
}
- SmallVector<int64_t, 8> newLb(getNumCols(), 0);
- SmallVector<int64_t, 8> newUb(getNumCols(), 0);
+ std::fill(newLb.begin(), newLb.end(), 0);
+ std::fill(newUb.begin(), newUb.end(), 0);
// The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
// and so it's the divisor for newLb and newUb as well.
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index bf2f82e29b1..ba6f79d6d77 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -70,9 +70,7 @@ bool ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
// Add loop bound constraints for values which are loop IVs and equality
// constraints for symbols which are constants.
for (const auto &value : values) {
- unsigned loc;
- (void)loc;
- assert(cst->findId(*value, &loc));
+ assert(cst->containsId(*value) && "value expected to be present");
if (isValidSymbol(value)) {
// Check if the symbol is a constant.
if (auto *inst = value->getDefiningInst()) {
@@ -256,8 +254,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
if (sliceState != nullptr) {
// Add dim and symbol slice operands.
for (const auto &operand : sliceState->lbOperands[0]) {
- unsigned loc;
- if (!cst.findId(*operand, &loc)) {
+ if (!cst.containsId(*operand)) {
if (isValidSymbol(operand)) {
cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand));
// Check if the symbol is a constant.
OpenPOWER on IntegriCloud