diff options
-rw-r--r-- | llvm/include/llvm/ADT/STLExtras.h | 72 | ||||
-rw-r--r-- | llvm/unittests/ADT/STLExtrasTest.cpp | 49 |
2 files changed, 121 insertions, 0 deletions
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index e6215e4ae5b..8d1ae049c08 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -626,6 +626,78 @@ template <typename T> struct deref { } }; +namespace detail { +template <typename I, typename V> class enumerator_impl { +public: + template <typename X> struct result_pair { + result_pair(std::size_t Index, X Value) : Index(Index), Value(Value) {} + + const std::size_t Index; + X Value; + }; + + struct iterator { + iterator(I Iter, std::size_t Index) : Iter(Iter), Index(Index) {} + + result_pair<const V> operator*() const { + return result_pair<const V>(Index, *Iter); + } + result_pair<V> operator*() { return result_pair<V>(Index, *Iter); } + + iterator &operator++() { + ++Iter; + ++Index; + return *this; + } + + bool operator!=(const iterator &RHS) const { return Iter != RHS.Iter; } + + private: + I Iter; + std::size_t Index; + }; + + enumerator_impl(I Begin, I End) + : Begin(std::move(Begin)), End(std::move(End)) {} + + iterator begin() { return iterator(Begin, 0); } + iterator end() { return iterator(End, std::size_t(-1)); } + + iterator begin() const { return iterator(Begin, 0); } + iterator end() const { return iterator(End, std::size_t(-1)); } + +private: + I Begin; + I End; +}; + +template <typename I> +auto make_enumerator(I Begin, I End) -> enumerator_impl<I, decltype(*Begin)> { + return enumerator_impl<I, decltype(*Begin)>(std::move(Begin), std::move(End)); +} +} + +/// Given an input range, returns a new range whose values are are pair (A,B) +/// such that A is the 0-based index of the item in the sequence, and B is +/// the value from the original sequence. Example: +/// +/// std::vector<char> Items = {'A', 'B', 'C', 'D'}; +/// for (auto X : enumerate(Items)) { +/// printf("Item %d - %c\n", X.Item, X.Value); +/// } +/// +/// Output: +/// Item 0 - A +/// Item 1 - B +/// Item 2 - C +/// Item 3 - D +/// +template <typename R> +auto enumerate(R &&Range) + -> decltype(detail::make_enumerator(std::begin(Range), std::end(Range))) { + return detail::make_enumerator(std::begin(Range), std::end(Range)); +} + } // End llvm namespace #endif diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index dc62b03741c..ebb119600c7 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -10,6 +10,8 @@ #include "llvm/ADT/STLExtras.h" #include "gtest/gtest.h" +#include <vector> + using namespace llvm; namespace { @@ -37,4 +39,51 @@ TEST(STLExtrasTest, Rank) { EXPECT_EQ(4, f(rank<6>())); } +TEST(STLExtrasTest, Enumerate) { + std::vector<char> foo = {'a', 'b', 'c'}; + + std::vector<std::pair<std::size_t, char>> results; + + for (auto X : llvm::enumerate(foo)) { + results.push_back(std::make_pair(X.Index, X.Value)); + } + ASSERT_EQ(3u, results.size()); + EXPECT_EQ(0u, results[0].first); + EXPECT_EQ('a', results[0].second); + EXPECT_EQ(1u, results[1].first); + EXPECT_EQ('b', results[1].second); + EXPECT_EQ(2u, results[2].first); + EXPECT_EQ('c', results[2].second); + + results.clear(); + const std::vector<int> bar = {'1', '2', '3'}; + for (auto X : llvm::enumerate(bar)) { + results.push_back(std::make_pair(X.Index, X.Value)); + } + EXPECT_EQ(0u, results[0].first); + EXPECT_EQ('1', results[0].second); + EXPECT_EQ(1u, results[1].first); + EXPECT_EQ('2', results[1].second); + EXPECT_EQ(2u, results[2].first); + EXPECT_EQ('3', results[2].second); + + results.clear(); + const std::vector<int> baz; + for (auto X : llvm::enumerate(baz)) { + results.push_back(std::make_pair(X.Index, X.Value)); + } + EXPECT_TRUE(baz.empty()); +} + +TEST(STLExtrasTest, EnumerateModify) { + std::vector<char> foo = {'a', 'b', 'c'}; + + for (auto X : llvm::enumerate(foo)) { + ++X.Value; + } + + EXPECT_EQ('b', foo[0]); + EXPECT_EQ('c', foo[1]); + EXPECT_EQ('d', foo[2]); +} } |