summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-09 12:55:05 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-09 12:55:40 -0800
commit7be6a40ab9b914b14ab61ae13e47e0bb8237e74d (patch)
tree5ba05e5e9d2e88714654e891af3d2b69bbd923d2
parent56da74476c48cfb6af1eb32ad191c3463a7e10e3 (diff)
downloadbcm5719-llvm-7be6a40ab9b914b14ab61ae13e47e0bb8237e74d.tar.gz
bcm5719-llvm-7be6a40ab9b914b14ab61ae13e47e0bb8237e74d.zip
Add new indexed_accessor_range_base and indexed_accessor_range classes that simplify defining index-able ranges.
Many ranges want similar functionality from a range type(e.g. slice/drop_front/operator[]/etc.), so these classes provide a generic implementation that may be used by many different types of ranges. This removes some code duplication, and also empowers many of the existing range types in MLIR(e.g. result type ranges, operand ranges, ElementsAttr ranges, etc.). This change only updates RegionRange and ValueRange, more ranges will be updated in followup commits. PiperOrigin-RevId: 284615679
-rw-r--r--mlir/include/mlir/IR/Attributes.h6
-rw-r--r--mlir/include/mlir/IR/Block.h4
-rw-r--r--mlir/include/mlir/IR/Operation.h67
-rw-r--r--mlir/include/mlir/IR/Region.h45
-rw-r--r--mlir/include/mlir/Support/STLExtras.h140
-rw-r--r--mlir/lib/IR/Attributes.cpp2
-rw-r--r--mlir/lib/IR/Operation.cpp57
-rw-r--r--mlir/lib/IR/Region.cpp26
8 files changed, 206 insertions, 141 deletions
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 3968d44dd37..59df75dc483 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -641,12 +641,12 @@ protected:
/// Return the current index for this iterator, adjusted for the case of a
/// splat.
ptrdiff_t getDataIndex() const {
- bool isSplat = this->object.getInt();
+ bool isSplat = this->base.getInt();
return isSplat ? 0 : this->index;
}
- /// Return the data object pointer.
- const char *getData() const { return this->object.getPointer(); }
+ /// Return the data base pointer.
+ const char *getData() const { return this->base.getPointer(); }
};
} // namespace detail
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index f01f1915d44..532352eb501 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -460,9 +460,9 @@ public:
Block *>(object, index) {}
SuccessorIterator(const SuccessorIterator &other)
- : SuccessorIterator(other.object, other.index) {}
+ : SuccessorIterator(other.base, other.index) {}
- Block *operator*() const { return this->object->getSuccessor(this->index); }
+ Block *operator*() const { return this->base->getSuccessor(this->index); }
/// Get the successor number in the terminator.
unsigned getSuccessorIndex() const { return this->index; }
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 75ea9727d4a..037c4fcf3df 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -668,7 +668,7 @@ public:
: indexed_accessor_iterator<OperandIterator, Operation *, Value *,
Value *, Value *>(object, index) {}
- Value *operator*() const { return this->object->getOperand(this->index); }
+ Value *operator*() const { return this->base->getOperand(this->index); }
};
/// This class implements the operand type iterators for the Operation
@@ -721,11 +721,11 @@ class ResultIterator final
Value *, Value *> {
public:
/// Initializes the result iterator to the specified index.
- ResultIterator(Operation *object, unsigned index)
+ ResultIterator(Operation *base, unsigned index)
: indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
- Value *>(object, index) {}
+ Value *>(base, index) {}
- Value *operator*() const { return this->object->getResult(this->index); }
+ Value *operator*() const { return this->base->getResult(this->index); }
};
/// This class implements the result type iterators for the Operation
@@ -799,15 +799,19 @@ inline auto Operation::getResultTypes() -> result_type_range {
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
-class ValueRange {
+class ValueRange
+ : public detail::indexed_accessor_range_base<
+ ValueRange,
+ llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
+ Value *, Value *> {
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
public:
- ValueRange(const ValueRange &) = default;
- ValueRange(ValueRange &&) = default;
- ValueRange &operator=(const ValueRange &) = default;
+ using detail::indexed_accessor_range_base<
+ ValueRange, OwnerT, Value *, Value *,
+ Value *>::indexed_accessor_range_base;
template <typename Arg,
typename = typename std::enable_if_t<
@@ -822,46 +826,15 @@ public:
ValueRange(iterator_range<OperandIterator> values);
ValueRange(iterator_range<ResultIterator> values);
- /// An iterator element of this range.
- class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Value *,
- Value *, Value *> {
- public:
- Value *operator*() const;
-
- private:
- Iterator(OwnerT owner, unsigned curIndex);
-
- /// Allow access to the constructor.
- friend ValueRange;
- };
-
- Iterator begin() const { return Iterator(owner, 0); }
- Iterator end() const { return Iterator(owner, count); }
- Value *operator[](unsigned index) const {
- assert(index < size() && "invalid index for value range");
- return *std::next(begin(), index);
- }
-
- /// Return the size of this range.
- size_t size() const { return count; }
-
- /// Return if the range is empty.
- bool empty() const { return size() == 0; }
-
- /// Drop the first N elements, and keep M elements.
- ValueRange slice(unsigned n, unsigned m) const;
- /// Drop the first n elements.
- ValueRange drop_front(unsigned n = 1) const;
- /// Drop the last n elements.
- ValueRange drop_back(unsigned n = 1) const;
-
private:
- ValueRange(OwnerT owner, unsigned count) : owner(owner), count(count) {}
-
- /// The object that owns the provided range of values.
- OwnerT owner;
- /// The size from the owning range.
- unsigned count;
+ /// See `detail::indexed_accessor_range_base` for details.
+ static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
+ /// See `detail::indexed_accessor_range_base` for details.
+ static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
+
+ /// Allow access to `offset_base` and `dereference_iterator`.
+ friend detail::indexed_accessor_range_base<ValueRange, OwnerT, Value *,
+ Value *, Value *>;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 933bf104506..3d25140b43c 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -165,14 +165,19 @@ private:
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
-class RegionRange {
+class RegionRange
+ : public detail::indexed_accessor_range_base<
+ RegionRange,
+ llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>,
+ Region *, Region *, Region *> {
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
public:
- RegionRange(const RegionRange &) = default;
- RegionRange(RegionRange &&) = default;
+ using detail::indexed_accessor_range_base<
+ RegionRange, OwnerT, Region *, Region *,
+ Region *>::indexed_accessor_range_base;
RegionRange(MutableArrayRef<Region> regions = llvm::None);
@@ -184,33 +189,15 @@ public:
}
RegionRange(ArrayRef<std::unique_ptr<Region>> regions);
- /// An iterator element of this range.
- class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Region *,
- Region *, Region *> {
- public:
- Region *operator*() const;
-
- private:
- Iterator(OwnerT owner, unsigned curIndex);
- /// Allow access to the constructor.
- friend RegionRange;
- };
- Iterator begin() const { return Iterator(owner, 0); }
- Iterator end() const { return Iterator(owner, count); }
- Region *operator[](unsigned index) const {
- assert(index < size() && "invalid index for region range");
- return *std::next(begin(), index);
- }
- /// Return the size of this range.
- size_t size() const { return count; }
- /// Return if the range is empty.
- bool empty() const { return size() == 0; }
-
private:
- /// The object that owns the provided range of regions.
- OwnerT owner;
- /// The size from the owning range.
- unsigned count;
+ /// See `detail::indexed_accessor_range_base` for details.
+ static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
+ /// See `detail::indexed_accessor_range_base` for details.
+ static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
+
+ /// Allow access to `offset_base` and `dereference_iterator`.
+ friend detail::indexed_accessor_range_base<RegionRange, OwnerT, Region *,
+ Region *, Region *>;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h
index 95e52f94e22..07db06ae6a3 100644
--- a/mlir/include/mlir/Support/STLExtras.h
+++ b/mlir/include/mlir/Support/STLExtras.h
@@ -147,9 +147,9 @@ using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
// Extra additions to <iterator>
//===----------------------------------------------------------------------===//
-/// A utility class used to implement an iterator that contains some object and
-/// an index. The iterator moves the index but keeps the object constant.
-template <typename DerivedT, typename ObjectType, typename T,
+/// A utility class used to implement an iterator that contains some base object
+/// and an index. The iterator moves the index but keeps the base constant.
+template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_iterator
: public llvm::iterator_facade_base<DerivedT,
@@ -157,14 +157,14 @@ class indexed_accessor_iterator
std::ptrdiff_t, PointerT, ReferenceT> {
public:
ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
- assert(object == rhs.object && "incompatible iterators");
+ assert(base == rhs.base && "incompatible iterators");
return index - rhs.index;
}
bool operator==(const indexed_accessor_iterator &rhs) const {
- return object == rhs.object && index == rhs.index;
+ return base == rhs.base && index == rhs.index;
}
bool operator<(const indexed_accessor_iterator &rhs) const {
- assert(object == rhs.object && "incompatible iterators");
+ assert(base == rhs.base && "incompatible iterators");
return index < rhs.index;
}
@@ -180,16 +180,134 @@ public:
/// Returns the current index of the iterator.
ptrdiff_t getIndex() const { return index; }
- /// Returns the current object of the iterator.
- const ObjectType &getObject() const { return object; }
+ /// Returns the current base of the iterator.
+ const BaseT &getBase() const { return base; }
protected:
- indexed_accessor_iterator(ObjectType object, ptrdiff_t index)
- : object(object), index(index) {}
- ObjectType object;
+ indexed_accessor_iterator(BaseT base, ptrdiff_t index)
+ : base(base), index(index) {}
+ BaseT base;
ptrdiff_t index;
};
+namespace detail {
+/// The class represents the base of a range of indexed_accessor_iterators. It
+/// provides support for many different range functionalities, e.g.
+/// drop_front/slice/etc.. Derived range classes must implement the following
+/// static methods:
+/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
+/// - Derefence an iterator pointing to the base object at the given index.
+/// * BaseT offset_base(const BaseT &base, ptrdiff_t index)
+/// - Return a new base that is offset from the provide base by 'index'
+/// elements.
+template <typename DerivedT, typename BaseT, typename T,
+ typename PointerT = T *, typename ReferenceT = T &>
+class indexed_accessor_range_base {
+public:
+ /// An iterator element of this range.
+ class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
+ PointerT, ReferenceT> {
+ public:
+ // Index into this iterator, invoking a static method on the derived type.
+ ReferenceT operator*() const {
+ return DerivedT::dereference_iterator(this->getBase(), this->getIndex());
+ }
+
+ private:
+ iterator(BaseT owner, ptrdiff_t curIndex)
+ : indexed_accessor_iterator<iterator, BaseT, T, PointerT, ReferenceT>(
+ owner, curIndex) {}
+
+ /// Allow access to the constructor.
+ friend indexed_accessor_range_base<DerivedT, BaseT, T, PointerT,
+ ReferenceT>;
+ };
+
+ iterator begin() const { return iterator(base, 0); }
+ iterator end() const { return iterator(base, count); }
+ ReferenceT operator[](unsigned index) const {
+ assert(index < size() && "invalid index for value range");
+ return *std::next(begin(), index);
+ }
+
+ /// Return the size of this range.
+ size_t size() const { return count; }
+
+ /// Return if the range is empty.
+ bool empty() const { return size() == 0; }
+
+ /// Drop the first N elements, and keep M elements.
+ DerivedT slice(unsigned n, unsigned m) const {
+ assert(n + m <= size() && "invalid size specifiers");
+ return DerivedT(DerivedT::offset_base(base, n), m);
+ }
+
+ /// Drop the first n elements.
+ DerivedT drop_front(unsigned n = 1) const {
+ assert(size() >= n && "Dropping more elements than exist");
+ return slice(n, size() - n);
+ }
+ /// Drop the last n elements.
+ DerivedT drop_back(unsigned n = 1) const {
+ assert(size() >= n && "Dropping more elements than exist");
+ return DerivedT(base, size() - n);
+ }
+
+protected:
+ indexed_accessor_range_base(BaseT base, ptrdiff_t count)
+ : base(base), count(count) {}
+ indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
+ indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
+ indexed_accessor_range_base &
+ operator=(const indexed_accessor_range_base &) = default;
+
+ /// The base that owns the provided range of values.
+ BaseT base;
+ /// The size from the owning range.
+ ptrdiff_t count;
+};
+} // end namespace detail
+
+/// This class provides an implementation of a range of
+/// indexed_accessor_iterators where the base is not indexable. Ranges with
+/// bases that are offsetable should derive from indexed_accessor_range_base
+/// instead. Derived range classes are expected to implement the following
+/// static method:
+/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
+/// - Derefence an iterator pointing to a parent base at the given index.
+template <typename DerivedT, typename BaseT, typename T,
+ typename PointerT = T *, typename ReferenceT = T &>
+class indexed_accessor_range
+ : public detail::indexed_accessor_range_base<
+ indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
+ std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
+protected:
+ indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
+ : detail::indexed_accessor_range_base<
+ DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
+ std::make_pair(base, startIndex), count) {}
+
+private:
+ /// See `detail::indexed_accessor_range_base` for details.
+ static std::pair<BaseT, ptrdiff_t>
+ offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
+ // We encode the internal base as a pair of the derived base and a start
+ // index into the derived base.
+ return std::make_pair(base.first, base.second + index);
+ }
+ /// See `detail::indexed_accessor_range_base` for details.
+ static ReferenceT
+ dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
+ ptrdiff_t index) {
+ return DerivedT::dereference_iterator(base.first, base.second + index);
+ }
+
+ /// Allow access to `offset_base` and `dereference_iterator`.
+ friend detail::indexed_accessor_range_base<
+ indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
+ std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>;
+};
+
/// Given a container of pairs, return a range over the second elements.
template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
return llvm::map_range(
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index f2f3d41f980..b546643837b 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -527,7 +527,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
/// Accesses the Attribute value at this iterator position.
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
- auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>();
+ auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
Type eltTy = owner.getType().getElementType();
if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
if (intEltTy.getWidth() == 1)
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index ae635d108b2..0483c27e968 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -750,60 +750,41 @@ Operation *Operation::clone() {
//===----------------------------------------------------------------------===//
ValueRange::ValueRange(ArrayRef<Value *> values)
- : owner(values.data()), count(values.size()) {}
+ : ValueRange(values.data(), values.size()) {}
ValueRange::ValueRange(llvm::iterator_range<OperandIterator> values)
- : count(llvm::size(values)) {
- if (count != 0) {
+ : ValueRange(nullptr, llvm::size(values)) {
+ if (!empty()) {
auto begin = values.begin();
- owner = &begin.getObject()->getOpOperand(begin.getIndex());
+ base = &begin.getBase()->getOpOperand(begin.getIndex());
}
}
ValueRange::ValueRange(llvm::iterator_range<ResultIterator> values)
- : count(llvm::size(values)) {
- if (count != 0) {
+ : ValueRange(nullptr, llvm::size(values)) {
+ if (!empty()) {
auto begin = values.begin();
- owner = &begin.getObject()->getOpResult(begin.getIndex());
+ base = &begin.getBase()->getOpResult(begin.getIndex());
}
}
-/// Drop the first N elements, and keep M elements.
-ValueRange ValueRange::slice(unsigned n, unsigned m) const {
- assert(n + m <= size() && "Invalid specifier");
- OwnerT newOwner;
+/// See `detail::indexed_accessor_range_base` for details.
+ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
+ ptrdiff_t index) {
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
- newOwner = operand + n;
- else if (OpResult *result = owner.dyn_cast<OpResult *>())
- newOwner = result + n;
- else
- newOwner = owner.get<Value *const *>() + n;
- return ValueRange(newOwner, m);
-}
-
-/// Drop the first n elements.
-ValueRange ValueRange::drop_front(unsigned n) const {
- assert(size() >= n && "Dropping more elements than exist");
- return slice(n, size() - n);
-}
-
-/// Drop the last n elements.
-ValueRange ValueRange::drop_back(unsigned n) const {
- assert(size() >= n && "Dropping more elements than exist");
- return ValueRange(owner, size() - n);
+ return operand + index;
+ if (OpResult *result = owner.dyn_cast<OpResult *>())
+ return result + index;
+ return owner.get<Value *const *>() + index;
}
-
-ValueRange::Iterator::Iterator(OwnerT owner, unsigned curIndex)
- : indexed_accessor_iterator<Iterator, OwnerT, Value *, Value *, Value *>(
- owner, curIndex) {}
-
-Value *ValueRange::Iterator::operator*() const {
+/// See `detail::indexed_accessor_range_base` for details.
+Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
// Operands access the held value via 'get'.
- if (OpOperand *operand = object.dyn_cast<OpOperand *>())
+ if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand[index].get();
// An OpResult is a value, so we can return it directly.
- if (OpResult *result = object.dyn_cast<OpResult *>())
+ if (OpResult *result = owner.dyn_cast<OpResult *>())
return &result[index];
// Otherwise, this is a raw value array so just index directly.
- return object.get<Value *const *>()[index];
+ return owner.get<Value *const *>()[index];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index a5a19cbcbf2..c588e567bc3 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -217,17 +217,23 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
//===----------------------------------------------------------------------===//
// RegionRange
//===----------------------------------------------------------------------===//
+
RegionRange::RegionRange(MutableArrayRef<Region> regions)
- : owner(regions.data()), count(regions.size()) {}
+ : RegionRange(regions.data(), regions.size()) {}
RegionRange::RegionRange(ArrayRef<std::unique_ptr<Region>> regions)
- : owner(regions.data()), count(regions.size()) {}
-RegionRange::Iterator::Iterator(OwnerT owner, unsigned curIndex)
- : indexed_accessor_iterator<Iterator, OwnerT, Region *, Region *, Region *>(
- owner, curIndex) {}
-
-Region *RegionRange::Iterator::operator*() const {
- if (const std::unique_ptr<Region> *operand =
- object.dyn_cast<const std::unique_ptr<Region> *>())
+ : RegionRange(regions.data(), regions.size()) {}
+
+/// See `detail::indexed_accessor_range_base` for details.
+RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
+ ptrdiff_t index) {
+ if (auto *operand = owner.dyn_cast<const std::unique_ptr<Region> *>())
+ return operand + index;
+ return &owner.get<Region *>()[index];
+}
+/// See `detail::indexed_accessor_range_base` for details.
+Region *RegionRange::dereference_iterator(const OwnerT &owner,
+ ptrdiff_t index) {
+ if (auto *operand = owner.dyn_cast<const std::unique_ptr<Region> *>())
return operand[index].get();
- return &object.get<Region *>()[index];
+ return &owner.get<Region *>()[index];
}
OpenPOWER on IntegriCloud