summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/ADT/STLExtras.h110
-rw-r--r--llvm/unittests/ADT/STLExtrasTest.cpp24
-rw-r--r--llvm/utils/TableGen/GlobalISelEmitter.cpp4
3 files changed, 98 insertions, 40 deletions
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 0389e3da4b1..111c7a1cb3b 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -28,6 +28,7 @@
#include <utility> // for std::pair
#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Compiler.h"
@@ -44,6 +45,10 @@ namespace detail {
template <typename RangeT>
using IterOfRange = decltype(std::begin(std::declval<RangeT &>()));
+template <typename RangeT>
+using ValueOfRange = typename std::remove_reference<decltype(
+ *std::begin(std::declval<RangeT &>()))>::type;
+
} // End detail namespace
//===----------------------------------------------------------------------===//
@@ -883,6 +888,14 @@ auto partition(R &&Range, UnaryPredicate P) -> decltype(std::begin(Range)) {
return std::partition(std::begin(Range), std::end(Range), P);
}
+/// \brief Given a range of type R, iterate the entire range and return a
+/// SmallVector with elements of the vector. This is useful, for example,
+/// when you want to iterate a range and then sort the results.
+template <unsigned Size, typename R>
+SmallVector<detail::ValueOfRange<R>, Size> to_vector(R &&Range) {
+ return {std::begin(Range), std::end(Range)};
+}
+
/// Provide a container algorithm similar to C++ Library Fundamentals v2's
/// `erase_if` which is equivalent to:
///
@@ -977,47 +990,82 @@ template <typename T> struct deref {
};
namespace detail {
-template <typename R> class enumerator_impl {
-public:
- template <typename X> struct result_pair {
- result_pair(std::size_t Index, X Value) : Index(Index), Value(Value) {}
+template <typename R> class enumerator_iter;
- const std::size_t Index;
- X Value;
- };
+template <typename R> struct result_pair {
+ friend class enumerator_iter<R>;
- class iterator {
- typedef
- typename std::iterator_traits<IterOfRange<R>>::reference iter_reference;
- typedef result_pair<iter_reference> result_type;
+ result_pair() : Index(-1) {}
+ result_pair(std::size_t Index, IterOfRange<R> Iter)
+ : Index(Index), Iter(Iter) {}
- public:
- iterator(IterOfRange<R> &&Iter, std::size_t Index)
- : Iter(Iter), Index(Index) {}
+ result_pair<R> &operator=(const result_pair<R> &Other) {
+ Index = Other.Index;
+ Iter = Other.Iter;
+ return *this;
+ }
+
+ std::size_t index() const { return Index; }
+ const ValueOfRange<R> &value() const { return *Iter; }
+ ValueOfRange<R> &value() { return *Iter; }
- result_type operator*() const { return result_type(Index, *Iter); }
+private:
+ std::size_t Index;
+ IterOfRange<R> Iter;
+};
- iterator &operator++() {
- ++Iter;
- ++Index;
- return *this;
- }
+template <typename R>
+class enumerator_iter
+ : public iterator_facade_base<
+ enumerator_iter<R>, std::forward_iterator_tag, result_pair<R>,
+ typename std::iterator_traits<IterOfRange<R>>::difference_type,
+ typename std::iterator_traits<IterOfRange<R>>::pointer,
+ typename std::iterator_traits<IterOfRange<R>>::reference> {
+ using result_type = result_pair<R>;
- bool operator!=(const iterator &RHS) const { return Iter != RHS.Iter; }
+public:
+ enumerator_iter(std::size_t Index, IterOfRange<R> Iter)
+ : Result(Index, Iter) {}
- private:
- IterOfRange<R> Iter;
- std::size_t Index;
- };
+ result_type &operator*() { return Result; }
+ const result_type &operator*() const { return Result; }
+ enumerator_iter<R> &operator++() {
+ assert(Result.Index != -1);
+ ++Result.Iter;
+ ++Result.Index;
+ return *this;
+ }
+
+ bool operator==(const enumerator_iter<R> &RHS) const {
+ // Don't compare indices here, only iterators. It's possible for an end
+ // iterator to have different indices depending on whether it was created
+ // by calling std::end() versus incrementing a valid iterator.
+ return Result.Iter == RHS.Result.Iter;
+ }
+
+ enumerator_iter<R> &operator=(const enumerator_iter<R> &Other) {
+ Result = Other.Result;
+ return *this;
+ }
+
+private:
+ result_type Result;
+};
+
+template <typename R> class enumerator {
public:
- explicit enumerator_impl(R &&Range) : Range(std::forward<R>(Range)) {}
+ explicit enumerator(R &&Range) : TheRange(std::forward<R>(Range)) {}
- iterator begin() { return iterator(std::begin(Range), 0); }
- iterator end() { return iterator(std::end(Range), std::size_t(-1)); }
+ enumerator_iter<R> begin() {
+ return enumerator_iter<R>(0, std::begin(TheRange));
+ }
+ enumerator_iter<R> end() {
+ return enumerator_iter<R>(-1, std::end(TheRange));
+ }
private:
- R Range;
+ R TheRange;
};
}
@@ -1036,8 +1084,8 @@ private:
/// Item 2 - C
/// Item 3 - D
///
-template <typename R> detail::enumerator_impl<R> enumerate(R &&Range) {
- return detail::enumerator_impl<R>(std::forward<R>(Range));
+template <typename R> detail::enumerator<R> enumerate(R &&TheRange) {
+ return detail::enumerator<R>(std::forward<R>(TheRange));
}
namespace detail {
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 6c162de6d37..dabd786a9ad 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -48,7 +48,7 @@ TEST(STLExtrasTest, EnumerateLValue) {
std::vector<CharPairType> CharResults;
for (auto X : llvm::enumerate(foo)) {
- CharResults.emplace_back(X.Index, X.Value);
+ CharResults.emplace_back(X.index(), X.value());
}
ASSERT_EQ(3u, CharResults.size());
EXPECT_EQ(CharPairType(0u, 'a'), CharResults[0]);
@@ -60,7 +60,7 @@ TEST(STLExtrasTest, EnumerateLValue) {
std::vector<IntPairType> IntResults;
const std::vector<int> bar = {1, 2, 3};
for (auto X : llvm::enumerate(bar)) {
- IntResults.emplace_back(X.Index, X.Value);
+ IntResults.emplace_back(X.index(), X.value());
}
ASSERT_EQ(3u, IntResults.size());
EXPECT_EQ(IntPairType(0u, 1), IntResults[0]);
@@ -71,7 +71,7 @@ TEST(STLExtrasTest, EnumerateLValue) {
IntResults.clear();
const std::vector<int> baz{};
for (auto X : llvm::enumerate(baz)) {
- IntResults.emplace_back(X.Index, X.Value);
+ IntResults.emplace_back(X.index(), X.value());
}
EXPECT_TRUE(IntResults.empty());
}
@@ -82,7 +82,7 @@ TEST(STLExtrasTest, EnumerateModifyLValue) {
std::vector<char> foo = {'a', 'b', 'c'};
for (auto X : llvm::enumerate(foo)) {
- ++X.Value;
+ ++X.value();
}
EXPECT_EQ('b', foo[0]);
EXPECT_EQ('c', foo[1]);
@@ -97,7 +97,7 @@ TEST(STLExtrasTest, EnumerateRValueRef) {
auto Enumerator = llvm::enumerate(std::vector<int>{1, 2, 3});
for (auto X : llvm::enumerate(std::vector<int>{1, 2, 3})) {
- Results.emplace_back(X.Index, X.Value);
+ Results.emplace_back(X.index(), X.value());
}
ASSERT_EQ(3u, Results.size());
@@ -114,8 +114,8 @@ TEST(STLExtrasTest, EnumerateModifyRValue) {
std::vector<PairType> Results;
for (auto X : llvm::enumerate(std::vector<char>{'1', '2', '3'})) {
- ++X.Value;
- Results.emplace_back(X.Index, X.Value);
+ ++X.value();
+ Results.emplace_back(X.index(), X.value());
}
ASSERT_EQ(3u, Results.size());
@@ -255,6 +255,16 @@ TEST(STLExtrasTest, CountAdaptor) {
EXPECT_EQ(1, count(v, 4));
}
+TEST(STLExtrasTest, ToVector) {
+ std::vector<char> v = {'a', 'b', 'c'};
+ auto Enumerated = to_vector<4>(enumerate(v));
+ ASSERT_EQ(3, Enumerated.size());
+ for (size_t I = 0; I < v.size(); ++I) {
+ EXPECT_EQ(I, Enumerated[I].index());
+ EXPECT_EQ(v[I], Enumerated[I].value());
+ }
+}
+
TEST(STLExtrasTest, ConcatRange) {
std::vector<int> Expected = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int> Test;
diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp
index 3df189bcf6e..58504cdb8e3 100644
--- a/llvm/utils/TableGen/GlobalISelEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp
@@ -582,9 +582,9 @@ private:
/// True if the instruction can be built solely by mutating the opcode.
bool canMutate() const {
for (const auto &Renderer : enumerate(OperandRenderers)) {
- if (const auto *Copy = dyn_cast<CopyRenderer>(&*Renderer.Value)) {
+ if (const auto *Copy = dyn_cast<CopyRenderer>(&*Renderer.value())) {
if (Matched.getOperand(Copy->getSymbolicName()).getOperandIndex() !=
- Renderer.Index)
+ Renderer.index())
return false;
} else
return false;
OpenPOWER on IntegriCloud