diff options
Diffstat (limited to 'lldb/unittests/Host/SocketTest.cpp')
-rw-r--r-- | lldb/unittests/Host/SocketTest.cpp | 80 |
1 files changed, 75 insertions, 5 deletions
diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index 939a164a9ce..e3e52274476 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -19,9 +19,9 @@ #include "gtest/gtest.h" -#include "SocketUtil.h" - #include "lldb/Host/Config.h" +#include "lldb/Host/Socket.h" +#include "lldb/Host/common/TCPSocket.h" #include "lldb/Host/common/UDPSocket.h" #ifndef LLDB_DISABLE_POSIX @@ -49,6 +49,52 @@ class SocketTest : public testing::Test ::WSACleanup(); #endif } + + protected: + static void + AcceptThread(Socket *listen_socket, const char *listen_remote_address, bool child_processes_inherit, + Socket **accept_socket, Error *error) + { + *error = listen_socket->Accept(listen_remote_address, child_processes_inherit, *accept_socket); + } + + template<typename SocketType> + void + CreateConnectedSockets(const char *listen_remote_address, const std::function<std::string(const SocketType&)> &get_connect_addr, std::unique_ptr<SocketType> *a_up, std::unique_ptr<SocketType> *b_up) + { + bool child_processes_inherit = false; + Error error; + std::unique_ptr<SocketType> listen_socket_up(new SocketType(child_processes_inherit, error)); + EXPECT_FALSE(error.Fail()); + error = listen_socket_up->Listen(listen_remote_address, 5); + EXPECT_FALSE(error.Fail()); + EXPECT_TRUE(listen_socket_up->IsValid()); + + Error accept_error; + Socket *accept_socket; + std::thread accept_thread(AcceptThread, listen_socket_up.get(), listen_remote_address, child_processes_inherit, + &accept_socket, &accept_error); + + std::string connect_remote_address = get_connect_addr(*listen_socket_up); + std::unique_ptr<SocketType> connect_socket_up(new SocketType(child_processes_inherit, error)); + EXPECT_FALSE(error.Fail()); + error = connect_socket_up->Connect(connect_remote_address.c_str()); + EXPECT_FALSE(error.Fail()); + EXPECT_TRUE(connect_socket_up->IsValid()); + + a_up->swap(connect_socket_up); + EXPECT_TRUE(error.Success()); + EXPECT_NE(nullptr, a_up->get()); + EXPECT_TRUE((*a_up)->IsValid()); + + accept_thread.join(); + b_up->reset(static_cast<SocketType*>(accept_socket)); + EXPECT_TRUE(accept_error.Success()); + EXPECT_NE(nullptr, b_up->get()); + EXPECT_TRUE((*b_up)->IsValid()); + + listen_socket_up.reset(); + } }; TEST_F (SocketTest, DecodeHostAndPort) @@ -102,20 +148,44 @@ TEST_F (SocketTest, DomainListenConnectAccept) const std::string file_name(file_name_str); free(file_name_str); - CreateConnectedSockets<DomainSocket>(file_name.c_str(), [=](const DomainSocket &) { return file_name; }); + std::unique_ptr<DomainSocket> socket_a_up; + std::unique_ptr<DomainSocket> socket_b_up; + CreateConnectedSockets<DomainSocket>(file_name.c_str(), + [=](const DomainSocket &) + { + return file_name; + }, + &socket_a_up, &socket_b_up); } #endif TEST_F (SocketTest, TCPListen0ConnectAccept) { - CreateConnectedTCPSockets(); + std::unique_ptr<TCPSocket> socket_a_up; + std::unique_ptr<TCPSocket> socket_b_up; + CreateConnectedSockets<TCPSocket>("127.0.0.1:0", + [=](const TCPSocket &s) + { + char connect_remote_address[64]; + snprintf(connect_remote_address, sizeof(connect_remote_address), "localhost:%u", s.GetLocalPortNumber()); + return std::string(connect_remote_address); + }, + &socket_a_up, &socket_b_up); } TEST_F (SocketTest, TCPGetAddress) { std::unique_ptr<TCPSocket> socket_a_up; std::unique_ptr<TCPSocket> socket_b_up; - std::tie(socket_a_up, socket_b_up) = CreateConnectedTCPSockets(); + CreateConnectedSockets<TCPSocket>("127.0.0.1:0", + [=](const TCPSocket &s) + { + char connect_remote_address[64]; + snprintf(connect_remote_address, sizeof(connect_remote_address), "localhost:%u", s.GetLocalPortNumber()); + return std::string(connect_remote_address); + }, + &socket_a_up, + &socket_b_up); EXPECT_EQ (socket_a_up->GetLocalPortNumber (), socket_b_up->GetRemotePortNumber ()); EXPECT_EQ (socket_b_up->GetLocalPortNumber (), socket_a_up->GetRemotePortNumber ()); |