Started working on network hang fix.
Reviewers: buda Reviewed By: buda Subscribers: pullbot, mislav.bradac Differential Revision: https://phabricator.memgraph.io/D228
This commit is contained in:
parent
2032466e2a
commit
db740fb9fc
@ -85,12 +85,21 @@ bool Socket::Bind(NetworkEndpoint& endpoint) {
|
||||
|
||||
if (bind(sfd, it->ai_addr, it->ai_addrlen) == 0) {
|
||||
socket_ = sfd;
|
||||
endpoint_ = endpoint;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (socket_ == -1) return false;
|
||||
|
||||
// detect bound port, used when the server binds to a random port
|
||||
struct sockaddr_in6 portdata;
|
||||
socklen_t portdatalen = sizeof(portdata);
|
||||
if (getsockname(socket_, (struct sockaddr *) &portdata, &portdatalen) < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
endpoint_ = NetworkEndpoint(endpoint.address(), ntohs(portdata.sin6_port));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -128,6 +137,20 @@ bool Socket::SetKeepAlive() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Socket::SetTimeout(long sec, long usec) {
|
||||
struct timeval tv;
|
||||
tv.tv_sec = sec;
|
||||
tv.tv_usec = usec;
|
||||
|
||||
if (setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
|
||||
return false;
|
||||
|
||||
if (setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; }
|
||||
|
||||
bool Socket::Accept(Socket* s) {
|
||||
@ -180,7 +203,10 @@ bool Socket::Write(const char* data, size_t len) {
|
||||
|
||||
bool Socket::Write(const uint8_t* data, size_t len) {
|
||||
while (len > 0) {
|
||||
auto written = send(socket_, data, len, 0);
|
||||
// MSG_NOSIGNAL is here to disable raising a SIGPIPE
|
||||
// signal when a connection dies mid-write, the socket
|
||||
// will only return an EPIPE error
|
||||
auto written = send(socket_, data, len, MSG_NOSIGNAL);
|
||||
if (UNLIKELY(written == -1)) return false;
|
||||
len -= written;
|
||||
data += written;
|
||||
|
@ -97,6 +97,18 @@ class Socket {
|
||||
*/
|
||||
bool SetKeepAlive();
|
||||
|
||||
/**
|
||||
* Sets the socket timeout.
|
||||
*
|
||||
* @param sec timeout seconds value
|
||||
* @param usec timeout microseconds value
|
||||
* @return set socket timeout status:
|
||||
* true if the timeout was successfully set to
|
||||
* sec seconds + usec microseconds
|
||||
* false if the timeout was not set
|
||||
*/
|
||||
bool SetTimeout(long sec, long usec);
|
||||
|
||||
// TODO: this will be removed
|
||||
operator int();
|
||||
|
||||
|
@ -46,40 +46,38 @@ class StreamReader : public StreamListener<Derived, Stream> {
|
||||
void OnData(Stream& stream) {
|
||||
logger_.trace("On data");
|
||||
|
||||
while (true) {
|
||||
if (UNLIKELY(!stream.Alive())) {
|
||||
logger_.trace("Calling OnClose because the stream isn't alive!");
|
||||
this->derived().OnClose(stream);
|
||||
break;
|
||||
}
|
||||
|
||||
// allocate the buffer to fill the data
|
||||
auto buf = this->derived().OnAlloc(stream);
|
||||
|
||||
// read from the buffer at most buf.len bytes
|
||||
buf.len = stream.socket_.Read(buf.data, buf.len);
|
||||
|
||||
// check for read errors
|
||||
if (buf.len == -1) {
|
||||
// this means we have read all available data
|
||||
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
|
||||
break;
|
||||
}
|
||||
|
||||
// some other error occurred, check errno
|
||||
this->derived().OnError(stream);
|
||||
break;
|
||||
}
|
||||
|
||||
// end of file, the client has closed the connection
|
||||
if (UNLIKELY(buf.len == 0)) {
|
||||
logger_.trace("Calling OnClose because the socket is closed!");
|
||||
this->derived().OnClose(stream);
|
||||
break;
|
||||
}
|
||||
|
||||
this->derived().OnRead(stream, buf);
|
||||
if (UNLIKELY(!stream.Alive())) {
|
||||
logger_.trace("Calling OnClose because the stream isn't alive!");
|
||||
this->derived().OnClose(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
// allocate the buffer to fill the data
|
||||
auto buf = this->derived().OnAlloc(stream);
|
||||
|
||||
// read from the buffer at most buf.len bytes
|
||||
buf.len = stream.socket_.Read(buf.data, buf.len);
|
||||
|
||||
// check for read errors
|
||||
if (buf.len == -1) {
|
||||
// this means we have read all available data
|
||||
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// some other error occurred, check errno
|
||||
this->derived().OnError(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
// end of file, the client has closed the connection
|
||||
if (UNLIKELY(buf.len == 0)) {
|
||||
logger_.trace("Calling OnClose because the socket is closed!");
|
||||
this->derived().OnClose(stream);
|
||||
return;
|
||||
}
|
||||
|
||||
this->derived().OnRead(stream, buf);
|
||||
}
|
||||
|
||||
private:
|
||||
|
127
tests/concurrent/network_read_hang.cpp
Normal file
127
tests/concurrent/network_read_hang.cpp
Normal file
@ -0,0 +1,127 @@
|
||||
#ifndef NDEBUG
|
||||
#define NDEBUG
|
||||
#endif
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "logging/default.hpp"
|
||||
#include "logging/streams/stdout.hpp"
|
||||
|
||||
#include "communication/server.hpp"
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "query/engine.hpp"
|
||||
|
||||
static constexpr const char interface[] = "127.0.0.1";
|
||||
|
||||
using endpoint_t = io::network::NetworkEndpoint;
|
||||
using socket_t = io::network::Socket;
|
||||
|
||||
class TestOutputStream {};
|
||||
|
||||
class TestSession {
|
||||
public:
|
||||
TestSession(socket_t&& socket, Dbms& dbms,
|
||||
QueryEngine<TestOutputStream>& query_engine)
|
||||
: socket_(std::move(socket)) {
|
||||
event_.data.ptr = this;
|
||||
}
|
||||
|
||||
bool Alive() { return socket_.IsOpen(); }
|
||||
|
||||
int Id() const { return socket_.id(); }
|
||||
|
||||
void Execute(const byte* data, size_t len) {
|
||||
this->socket_.Write(data, len);
|
||||
}
|
||||
|
||||
void Close() {
|
||||
this->socket_.Close();
|
||||
}
|
||||
|
||||
socket_t socket_;
|
||||
io::network::Epoll::Event event_;
|
||||
};
|
||||
|
||||
using test_server_t =
|
||||
communication::Server<TestSession, TestOutputStream, socket_t>;
|
||||
|
||||
test_server_t *serverptr;
|
||||
std::atomic<bool> run{true};
|
||||
|
||||
void client_run(int num, const char* interface, const char* port) {
|
||||
endpoint_t endpoint(interface, port);
|
||||
socket_t socket;
|
||||
uint8_t data = 0x00;
|
||||
ASSERT_TRUE(socket.Connect(endpoint));
|
||||
ASSERT_TRUE(socket.SetTimeout(1, 0));
|
||||
// set socket timeout to 1s
|
||||
ASSERT_TRUE(socket.Write((uint8_t *)"\xAA", 1));
|
||||
ASSERT_TRUE(socket.Read(&data, 1));
|
||||
fprintf(stderr, "CLIENT %d READ 0x%02X!\n", num, data);
|
||||
ASSERT_EQ(data, 0xAA);
|
||||
while (run)
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
socket.Close();
|
||||
}
|
||||
|
||||
void server_run(void* serverptr, int num) {
|
||||
((test_server_t*)serverptr)->Start(num);
|
||||
}
|
||||
|
||||
TEST(Network, SocketReadHangOnConcurrentConnections) {
|
||||
// initialize listen socket
|
||||
endpoint_t endpoint(interface, "0");
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Bind(endpoint));
|
||||
ASSERT_TRUE(socket.SetNonBlocking());
|
||||
ASSERT_TRUE(socket.Listen(1024));
|
||||
|
||||
// get bound address
|
||||
auto ep = socket.endpoint();
|
||||
printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port());
|
||||
|
||||
// initialize server
|
||||
Dbms dbms;
|
||||
QueryEngine<TestOutputStream> query_engine;
|
||||
test_server_t server(std::move(socket), dbms, query_engine);
|
||||
serverptr = &server;
|
||||
|
||||
// start server
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
int Nc = N * 3;
|
||||
std::thread server_thread(server_run, serverptr, N);
|
||||
|
||||
// start clients
|
||||
std::vector<std::thread> clients;
|
||||
for (int i = 0; i < Nc; ++i)
|
||||
clients.push_back(
|
||||
std::thread(client_run, i, interface, ep.port_str()));
|
||||
|
||||
// wait for 2s and stop clients
|
||||
std::this_thread::sleep_for(std::chrono::seconds(2));
|
||||
run = false;
|
||||
|
||||
// cleanup clients
|
||||
for (int i = 0; i < Nc; ++i) clients[i].join();
|
||||
|
||||
// stop server
|
||||
server.Shutdown();
|
||||
server_thread.join();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
logging::init_async();
|
||||
logging::log->pipe(std::make_unique<Stdout>());
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -5,7 +5,6 @@
|
||||
#include "network_common.hpp"
|
||||
|
||||
static constexpr const char interface[] = "127.0.0.1";
|
||||
static constexpr const char port[] = "30000";
|
||||
|
||||
unsigned char data[SIZE];
|
||||
|
||||
@ -16,12 +15,16 @@ TEST(Network, Server) {
|
||||
initialize_data(data, SIZE);
|
||||
|
||||
// initialize listen socket
|
||||
endpoint_t endpoint(interface, port);
|
||||
endpoint_t endpoint(interface, "0");
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Bind(endpoint));
|
||||
ASSERT_TRUE(socket.SetNonBlocking());
|
||||
ASSERT_TRUE(socket.Listen(1024));
|
||||
|
||||
// get bound address
|
||||
auto ep = socket.endpoint();
|
||||
printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port());
|
||||
|
||||
// initialize server
|
||||
Dbms dbms;
|
||||
QueryEngine<TestOutputStream> query_engine;
|
||||
@ -29,14 +32,14 @@ TEST(Network, Server) {
|
||||
serverptr = &server;
|
||||
|
||||
// start server
|
||||
int N = std::thread::hardware_concurrency() / 2;
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
std::thread server_thread(server_start, serverptr, N);
|
||||
|
||||
// start clients
|
||||
std::vector<std::thread> clients;
|
||||
for (int i = 0; i < N; ++i)
|
||||
clients.push_back(
|
||||
std::thread(client_run, i, interface, port, data, 30000, SIZE));
|
||||
std::thread(client_run, i, interface, ep.port_str(), data, 30000, SIZE));
|
||||
|
||||
// cleanup clients
|
||||
for (int i = 0; i < N; ++i) clients[i].join();
|
||||
|
@ -7,7 +7,6 @@
|
||||
#include "network_common.hpp"
|
||||
|
||||
static constexpr const char interface[] = "127.0.0.1";
|
||||
static constexpr const char port[] = "40000";
|
||||
|
||||
unsigned char data[SIZE];
|
||||
|
||||
@ -20,12 +19,16 @@ TEST(Network, SessionLeak) {
|
||||
initialize_data(data, SIZE);
|
||||
|
||||
// initialize listen socket
|
||||
endpoint_t endpoint(interface, port);
|
||||
endpoint_t endpoint(interface, "0");
|
||||
socket_t socket;
|
||||
ASSERT_TRUE(socket.Bind(endpoint));
|
||||
ASSERT_TRUE(socket.SetNonBlocking());
|
||||
ASSERT_TRUE(socket.Listen(1024));
|
||||
|
||||
// get bound address
|
||||
auto ep = socket.endpoint();
|
||||
printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port());
|
||||
|
||||
// initialize server
|
||||
Dbms dbms;
|
||||
QueryEngine<TestOutputStream> query_engine;
|
||||
@ -42,7 +45,7 @@ TEST(Network, SessionLeak) {
|
||||
int testlen = 3000;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
clients.push_back(
|
||||
std::thread(client_run, i, interface, port, data, testlen, testlen));
|
||||
std::thread(client_run, i, interface, ep.port_str(), data, testlen, testlen));
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user