From db740fb9fc126515a91a8d39b9130d317259158c Mon Sep 17 00:00:00 2001 From: Matej Ferencevic Date: Fri, 7 Apr 2017 17:09:49 +0200 Subject: [PATCH] Started working on network hang fix. Reviewers: buda Reviewed By: buda Subscribers: pullbot, mislav.bradac Differential Revision: https://phabricator.memgraph.io/D228 --- src/io/network/socket.cpp | 30 ++++- src/io/network/socket.hpp | 12 ++ src/io/network/stream_reader.hpp | 64 ++++++----- tests/concurrent/network_read_hang.cpp | 127 ++++++++++++++++++++++ tests/concurrent/network_server.cpp | 11 +- tests/concurrent/network_session_leak.cpp | 9 +- 6 files changed, 211 insertions(+), 42 deletions(-) create mode 100644 tests/concurrent/network_read_hang.cpp diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp index fd11e8e94..87da79a71 100644 --- a/src/io/network/socket.cpp +++ b/src/io/network/socket.cpp @@ -85,12 +85,21 @@ bool Socket::Bind(NetworkEndpoint& endpoint) { if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) { socket_ = sfd; - endpoint_ = endpoint; break; } } if (socket_ == -1) return false; + + // detect bound port, used when the server binds to a random port + struct sockaddr_in6 portdata; + socklen_t portdatalen = sizeof(portdata); + if (getsockname(socket_, (struct sockaddr *) &portdata, &portdatalen) < 0) { + return false; + } + + endpoint_ = NetworkEndpoint(endpoint.address(), ntohs(portdata.sin6_port)); + return true; } @@ -128,6 +137,20 @@ bool Socket::SetKeepAlive() { return true; } +bool Socket::SetTimeout(long sec, long usec) { + struct timeval tv; + tv.tv_sec = sec; + tv.tv_usec = usec; + + if (setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) + return false; + + if (setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0) + return false; + + return true; +} + bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; } bool Socket::Accept(Socket* s) { @@ -180,7 +203,10 @@ bool Socket::Write(const char* data, size_t len) { bool Socket::Write(const uint8_t* data, size_t len) { while (len > 0) { - auto written = send(socket_, data, len, 0); + // MSG_NOSIGNAL is here to disable raising a SIGPIPE + // signal when a connection dies mid-write, the socket + // will only return an EPIPE error + auto written = send(socket_, data, len, MSG_NOSIGNAL); if (UNLIKELY(written == -1)) return false; len -= written; data += written; diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp index 29ba547be..b37f091d7 100644 --- a/src/io/network/socket.hpp +++ b/src/io/network/socket.hpp @@ -97,6 +97,18 @@ class Socket { */ bool SetKeepAlive(); + /** + * Sets the socket timeout. + * + * @param sec timeout seconds value + * @param usec timeout microseconds value + * @return set socket timeout status: + * true if the timeout was successfully set to + * sec seconds + usec microseconds + * false if the timeout was not set + */ + bool SetTimeout(long sec, long usec); + // TODO: this will be removed operator int(); diff --git a/src/io/network/stream_reader.hpp b/src/io/network/stream_reader.hpp index 96d5c4c06..ec1a4c05f 100644 --- a/src/io/network/stream_reader.hpp +++ b/src/io/network/stream_reader.hpp @@ -46,40 +46,38 @@ class StreamReader : public StreamListener { void OnData(Stream& stream) { logger_.trace("On data"); - while (true) { - if (UNLIKELY(!stream.Alive())) { - logger_.trace("Calling OnClose because the stream isn't alive!"); - this->derived().OnClose(stream); - break; - } - - // allocate the buffer to fill the data - auto buf = this->derived().OnAlloc(stream); - - // read from the buffer at most buf.len bytes - buf.len = stream.socket_.Read(buf.data, buf.len); - - // check for read errors - if (buf.len == -1) { - // this means we have read all available data - if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) { - break; - } - - // some other error occurred, check errno - this->derived().OnError(stream); - break; - } - - // end of file, the client has closed the connection - if (UNLIKELY(buf.len == 0)) { - logger_.trace("Calling OnClose because the socket is closed!"); - this->derived().OnClose(stream); - break; - } - - this->derived().OnRead(stream, buf); + if (UNLIKELY(!stream.Alive())) { + logger_.trace("Calling OnClose because the stream isn't alive!"); + this->derived().OnClose(stream); + return; } + + // allocate the buffer to fill the data + auto buf = this->derived().OnAlloc(stream); + + // read from the buffer at most buf.len bytes + buf.len = stream.socket_.Read(buf.data, buf.len); + + // check for read errors + if (buf.len == -1) { + // this means we have read all available data + if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) { + return; + } + + // some other error occurred, check errno + this->derived().OnError(stream); + return; + } + + // end of file, the client has closed the connection + if (UNLIKELY(buf.len == 0)) { + logger_.trace("Calling OnClose because the socket is closed!"); + this->derived().OnClose(stream); + return; + } + + this->derived().OnRead(stream, buf); } private: diff --git a/tests/concurrent/network_read_hang.cpp b/tests/concurrent/network_read_hang.cpp new file mode 100644 index 000000000..1de6c00c3 --- /dev/null +++ b/tests/concurrent/network_read_hang.cpp @@ -0,0 +1,127 @@ +#ifndef NDEBUG +#define NDEBUG +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "logging/default.hpp" +#include "logging/streams/stdout.hpp" + +#include "communication/server.hpp" +#include "dbms/dbms.hpp" +#include "io/network/epoll.hpp" +#include "io/network/socket.hpp" +#include "query/engine.hpp" + +static constexpr const char interface[] = "127.0.0.1"; + +using endpoint_t = io::network::NetworkEndpoint; +using socket_t = io::network::Socket; + +class TestOutputStream {}; + +class TestSession { + public: + TestSession(socket_t&& socket, Dbms& dbms, + QueryEngine& query_engine) + : socket_(std::move(socket)) { + event_.data.ptr = this; + } + + bool Alive() { return socket_.IsOpen(); } + + int Id() const { return socket_.id(); } + + void Execute(const byte* data, size_t len) { + this->socket_.Write(data, len); + } + + void Close() { + this->socket_.Close(); + } + + socket_t socket_; + io::network::Epoll::Event event_; +}; + +using test_server_t = + communication::Server; + +test_server_t *serverptr; +std::atomic run{true}; + +void client_run(int num, const char* interface, const char* port) { + endpoint_t endpoint(interface, port); + socket_t socket; + uint8_t data = 0x00; + ASSERT_TRUE(socket.Connect(endpoint)); + ASSERT_TRUE(socket.SetTimeout(1, 0)); + // set socket timeout to 1s + ASSERT_TRUE(socket.Write((uint8_t *)"\xAA", 1)); + ASSERT_TRUE(socket.Read(&data, 1)); + fprintf(stderr, "CLIENT %d READ 0x%02X!\n", num, data); + ASSERT_EQ(data, 0xAA); + while (run) + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + socket.Close(); +} + +void server_run(void* serverptr, int num) { + ((test_server_t*)serverptr)->Start(num); +} + +TEST(Network, SocketReadHangOnConcurrentConnections) { + // initialize listen socket + endpoint_t endpoint(interface, "0"); + socket_t socket; + ASSERT_TRUE(socket.Bind(endpoint)); + ASSERT_TRUE(socket.SetNonBlocking()); + ASSERT_TRUE(socket.Listen(1024)); + + // get bound address + auto ep = socket.endpoint(); + printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port()); + + // initialize server + Dbms dbms; + QueryEngine query_engine; + test_server_t server(std::move(socket), dbms, query_engine); + serverptr = &server; + + // start server + int N = (std::thread::hardware_concurrency() + 1) / 2; + int Nc = N * 3; + std::thread server_thread(server_run, serverptr, N); + + // start clients + std::vector clients; + for (int i = 0; i < Nc; ++i) + clients.push_back( + std::thread(client_run, i, interface, ep.port_str())); + + // wait for 2s and stop clients + std::this_thread::sleep_for(std::chrono::seconds(2)); + run = false; + + // cleanup clients + for (int i = 0; i < Nc; ++i) clients[i].join(); + + // stop server + server.Shutdown(); + server_thread.join(); +} + +int main(int argc, char **argv) { + logging::init_async(); + logging::log->pipe(std::make_unique()); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/concurrent/network_server.cpp b/tests/concurrent/network_server.cpp index 5d66f347d..5fa12f081 100644 --- a/tests/concurrent/network_server.cpp +++ b/tests/concurrent/network_server.cpp @@ -5,7 +5,6 @@ #include "network_common.hpp" static constexpr const char interface[] = "127.0.0.1"; -static constexpr const char port[] = "30000"; unsigned char data[SIZE]; @@ -16,12 +15,16 @@ TEST(Network, Server) { initialize_data(data, SIZE); // initialize listen socket - endpoint_t endpoint(interface, port); + endpoint_t endpoint(interface, "0"); socket_t socket; ASSERT_TRUE(socket.Bind(endpoint)); ASSERT_TRUE(socket.SetNonBlocking()); ASSERT_TRUE(socket.Listen(1024)); + // get bound address + auto ep = socket.endpoint(); + printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port()); + // initialize server Dbms dbms; QueryEngine query_engine; @@ -29,14 +32,14 @@ TEST(Network, Server) { serverptr = &server; // start server - int N = std::thread::hardware_concurrency() / 2; + int N = (std::thread::hardware_concurrency() + 1) / 2; std::thread server_thread(server_start, serverptr, N); // start clients std::vector clients; for (int i = 0; i < N; ++i) clients.push_back( - std::thread(client_run, i, interface, port, data, 30000, SIZE)); + std::thread(client_run, i, interface, ep.port_str(), data, 30000, SIZE)); // cleanup clients for (int i = 0; i < N; ++i) clients[i].join(); diff --git a/tests/concurrent/network_session_leak.cpp b/tests/concurrent/network_session_leak.cpp index d18c6bfa3..d3723a8e6 100644 --- a/tests/concurrent/network_session_leak.cpp +++ b/tests/concurrent/network_session_leak.cpp @@ -7,7 +7,6 @@ #include "network_common.hpp" static constexpr const char interface[] = "127.0.0.1"; -static constexpr const char port[] = "40000"; unsigned char data[SIZE]; @@ -20,12 +19,16 @@ TEST(Network, SessionLeak) { initialize_data(data, SIZE); // initialize listen socket - endpoint_t endpoint(interface, port); + endpoint_t endpoint(interface, "0"); socket_t socket; ASSERT_TRUE(socket.Bind(endpoint)); ASSERT_TRUE(socket.SetNonBlocking()); ASSERT_TRUE(socket.Listen(1024)); + // get bound address + auto ep = socket.endpoint(); + printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port()); + // initialize server Dbms dbms; QueryEngine query_engine; @@ -42,7 +45,7 @@ TEST(Network, SessionLeak) { int testlen = 3000; for (int i = 0; i < N; ++i) { clients.push_back( - std::thread(client_run, i, interface, port, data, testlen, testlen)); + std::thread(client_run, i, interface, ep.port_str(), data, testlen, testlen)); std::this_thread::sleep_for(10ms); }