diff options
| -rw-r--r-- | llvm/include/llvm/ADT/SmallPtrSet.h | 45 | ||||
| -rw-r--r-- | llvm/unittests/ADT/SmallPtrSetTest.cpp | 32 |
2 files changed, 48 insertions, 29 deletions
diff --git a/llvm/include/llvm/ADT/SmallPtrSet.h b/llvm/include/llvm/ADT/SmallPtrSet.h index 7234f0fbded..b98cf6c376b 100644 --- a/llvm/include/llvm/ADT/SmallPtrSet.h +++ b/llvm/include/llvm/ADT/SmallPtrSet.h @@ -260,11 +260,10 @@ protected: } #if LLVM_ENABLE_ABI_BREAKING_CHECKS void RetreatIfNotValid() { - --Bucket; - assert(Bucket <= End); + assert(Bucket >= End); while (Bucket != End && - (*Bucket == SmallPtrSetImplBase::getEmptyMarker() || - *Bucket == SmallPtrSetImplBase::getTombstoneMarker())) { + (Bucket[-1] == SmallPtrSetImplBase::getEmptyMarker() || + Bucket[-1] == SmallPtrSetImplBase::getTombstoneMarker())) { --Bucket; } } @@ -289,6 +288,12 @@ public: // Most methods provided by baseclass. const PtrTy operator*() const { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (ReverseIterate<bool>::value) { + assert(Bucket > End); + return PtrTraits::getFromVoidPointer(const_cast<void *>(Bucket[-1])); + } +#endif assert(Bucket < End); return PtrTraits::getFromVoidPointer(const_cast<void*>(*Bucket)); } @@ -296,6 +301,7 @@ public: inline SmallPtrSetIterator& operator++() { // Preincrement #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate<bool>::value) { + --Bucket; RetreatIfNotValid(); return *this; } @@ -370,7 +376,7 @@ public: /// the element equal to Ptr. std::pair<iterator, bool> insert(PtrType Ptr) { auto p = insert_imp(PtrTraits::getAsVoidPointer(Ptr)); - return std::make_pair(iterator(p.first, EndPointer()), p.second); + return std::make_pair(makeIterator(p.first), p.second); } /// erase - If the set contains the specified pointer, remove it and return @@ -379,12 +385,9 @@ public: return erase_imp(PtrTraits::getAsVoidPointer(Ptr)); } /// count - Return 1 if the specified pointer is in the set, 0 otherwise. - size_type count(ConstPtrType Ptr) const { - return find(Ptr) != endPtr() ? 1 : 0; - } + size_type count(ConstPtrType Ptr) const { return find(Ptr) != end() ? 1 : 0; } iterator find(ConstPtrType Ptr) const { - auto *P = find_imp(ConstPtrTraits::getAsVoidPointer(Ptr)); - return iterator(P, EndPointer()); + return makeIterator(find_imp(ConstPtrTraits::getAsVoidPointer(Ptr))); } template <typename IterT> @@ -397,25 +400,23 @@ public: insert(IL.begin(), IL.end()); } - inline iterator begin() const { + iterator begin() const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate<bool>::value) - return endPtr(); + return makeIterator(EndPointer() - 1); #endif - return iterator(CurArray, EndPointer()); + return makeIterator(CurArray); } - inline iterator end() const { + iterator end() const { return makeIterator(EndPointer()); } + +private: + /// Create an iterator that dereferences to same place as the given pointer. + iterator makeIterator(const void *const *P) const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate<bool>::value) - return iterator(CurArray, CurArray); + return iterator(P == EndPointer() ? CurArray : P + 1, CurArray); #endif - return endPtr(); - } - -private: - inline iterator endPtr() const { - const void *const *End = EndPointer(); - return iterator(End, End); + return iterator(P, EndPointer()); } }; diff --git a/llvm/unittests/ADT/SmallPtrSetTest.cpp b/llvm/unittests/ADT/SmallPtrSetTest.cpp index bb9ee67b7eb..fc14c684d67 100644 --- a/llvm/unittests/ADT/SmallPtrSetTest.cpp +++ b/llvm/unittests/ADT/SmallPtrSetTest.cpp @@ -282,6 +282,28 @@ TEST(SmallPtrSetTest, EraseTest) { checkEraseAndIterators(A); } +// Verify that dereferencing and iteration work. +TEST(SmallPtrSetTest, dereferenceAndIterate) { + int Ints[] = {0, 1, 2, 3, 4, 5, 6, 7}; + SmallPtrSet<const int *, 4> S; + for (int &I : Ints) { + EXPECT_EQ(&I, *S.insert(&I).first); + EXPECT_EQ(&I, *S.find(&I)); + } + + // Iterate from each and count how many times each element is found. + int Found[sizeof(Ints)/sizeof(int)] = {0}; + for (int &I : Ints) + for (auto F = S.find(&I), E = S.end(); F != E; ++F) + ++Found[*F - Ints]; + + // Sort. We should hit the first element just once and the final element N + // times. + std::sort(std::begin(Found), std::end(Found)); + for (auto F = std::begin(Found), E = std::end(Found); F != E; ++F) + EXPECT_EQ(F - Found + 1, *F); +} + // Verify that const pointers work for count and find even when the underlying // SmallPtrSet is not for a const pointer type. TEST(SmallPtrSetTest, ConstTest) { @@ -292,10 +314,8 @@ TEST(SmallPtrSetTest, ConstTest) { IntSet.insert(B); EXPECT_EQ(IntSet.count(B), 1u); EXPECT_EQ(IntSet.count(C), 1u); - // FIXME: We can't unit test find right now because ABI_BREAKING_CHECKS breaks - // find(). - // EXPECT_NE(IntSet.find(B), IntSet.end()); - // EXPECT_NE(IntSet.find(C), IntSet.end()); + EXPECT_NE(IntSet.find(B), IntSet.end()); + EXPECT_NE(IntSet.find(C), IntSet.end()); } // Verify that we automatically get the const version of PointerLikeTypeTraits @@ -308,7 +328,5 @@ TEST(SmallPtrSetTest, ConstNonPtrTest) { TestPair Pair(&A[0], 1); IntSet.insert(Pair); EXPECT_EQ(IntSet.count(Pair), 1u); - // FIXME: We can't unit test find right now because ABI_BREAKING_CHECKS breaks - // find(). - // EXPECT_NE(IntSet.find(Pair), IntSet.end()); + EXPECT_NE(IntSet.find(Pair), IntSet.end()); } |

