diff options
-rw-r--r-- | llvm/include/llvm/ADT/SmallSet.h | 25 | ||||
-rw-r--r-- | llvm/unittests/ADT/SmallSetTest.cpp | 25 |
2 files changed, 50 insertions, 0 deletions
diff --git a/llvm/include/llvm/ADT/SmallSet.h b/llvm/include/llvm/ADT/SmallSet.h index 6b128c2e299..a03fa7dd842 100644 --- a/llvm/include/llvm/ADT/SmallSet.h +++ b/llvm/include/llvm/ADT/SmallSet.h @@ -248,6 +248,31 @@ private: template <typename PointeeType, unsigned N> class SmallSet<PointeeType*, N> : public SmallPtrSet<PointeeType*, N> {}; +/// Equality comparison for SmallSet. +/// +/// Iterates over elements of LHS confirming that each element is also a member +/// of RHS, and that RHS contains no additional values. +/// Equivalent to N calls to RHS.count. +/// For small-set mode amortized complexity is O(N^2) +/// For large-set mode amortized complexity is linear, worst case is O(N^2) (if +/// every hash collides). +template <typename T, unsigned LN, unsigned RN, typename C> +bool operator==(const SmallSet<T, LN, C> &LHS, const SmallSet<T, RN, C> &RHS) { + if (LHS.size() != RHS.size()) + return false; + + // All elements in LHS must also be in RHS + return all_of(LHS, [&RHS](const T &E) { return RHS.count(E); }); +} + +/// Inequality comparison for SmallSet. +/// +/// Equivalent to !(LHS == RHS). See operator== for performance notes. +template <typename T, unsigned LN, unsigned RN, typename C> +bool operator!=(const SmallSet<T, LN, C> &LHS, const SmallSet<T, RN, C> &RHS) { + return !(LHS == RHS); +} + } // end namespace llvm #endif // LLVM_ADT_SMALLSET_H diff --git a/llvm/unittests/ADT/SmallSetTest.cpp b/llvm/unittests/ADT/SmallSetTest.cpp index 8fb78b01f44..06682ce823d 100644 --- a/llvm/unittests/ADT/SmallSetTest.cpp +++ b/llvm/unittests/ADT/SmallSetTest.cpp @@ -142,3 +142,28 @@ TEST(SmallSetTest, IteratorIncMoveCopy) { Iter = std::move(Iter2); EXPECT_EQ("str 0", *Iter); } + +TEST(SmallSetTest, EqualityComparisonTest) { + SmallSet<int, 8> s1small; + SmallSet<int, 10> s2small; + SmallSet<int, 3> s3large; + SmallSet<int, 8> s4large; + + for (int i = 1; i < 5; i++) { + s1small.insert(i); + s2small.insert(5 - i); + s3large.insert(i); + } + for (int i = 1; i < 11; i++) + s4large.insert(i); + + EXPECT_EQ(s1small, s1small); + EXPECT_EQ(s3large, s3large); + + EXPECT_EQ(s1small, s2small); + EXPECT_EQ(s1small, s3large); + EXPECT_EQ(s2small, s3large); + + EXPECT_NE(s1small, s4large); + EXPECT_NE(s4large, s3large); +} |