From 017e8004e8d5bc2b033ef382e12a589bf85c0e49 Mon Sep 17 00:00:00 2001 From: Matej Ferencevic Date: Thu, 22 Feb 2018 16:17:45 +0100 Subject: [PATCH] Refactor network stack Summary: Previously, the network stack `communication::Server` accepted connections and assigned them statically in a round-robin fashion to `communication::Worker`. That meant that if two compute intensive connections were assigned to the same worker they would block each other while the other workers would do nothing. This implementation replaces `communication::Worker` with `communication::Listener` which holds all accepted connections in one pool and ensures that all workers execute all connections. Reviewers: buda, florijan, teon.banek Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1220 --- src/communication/bolt/v1/session.hpp | 87 ++++--- .../bolt/v1/states/handshake.hpp | 2 +- src/communication/listener.hpp | 246 ++++++++++++++++++ src/communication/rpc/protocol.cpp | 10 - src/communication/rpc/protocol.hpp | 11 - src/communication/rpc/server.cpp | 2 +- src/communication/server.hpp | 131 +++------- src/communication/worker.hpp | 194 -------------- src/io/network/epoll.hpp | 65 ++++- src/io/network/socket.cpp | 22 +- src/io/network/socket.hpp | 21 +- src/io/network/socket_event_dispatcher.hpp | 95 ------- src/memgraph_bolt.cpp | 4 +- tests/concurrent/network_common.hpp | 6 - tests/concurrent/network_read_hang.cpp | 6 +- tests/concurrent/network_server.cpp | 2 +- tests/concurrent/network_session_leak.cpp | 3 +- tests/unit/bolt_common.hpp | 3 - tests/unit/bolt_session.cpp | 105 ++++---- tests/unit/network_timeouts.cpp | 10 +- 20 files changed, 462 insertions(+), 563 deletions(-) create mode 100644 src/communication/listener.hpp delete mode 100644 src/communication/worker.hpp delete mode 100644 src/io/network/socket_event_dispatcher.hpp diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index 
8be06befd..d677d1d63 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "glog/logging.h" #include "communication/bolt/v1/constants.hpp" @@ -17,7 +19,9 @@ #include "io/network/socket.hpp" #include "io/network/stream_buffer.hpp" #include "query/interpreter.hpp" +#include "threading/sync/spinlock.hpp" #include "transactions/transaction.hpp" +#include "utils/exceptions.hpp" DECLARE_int32(session_inactivity_timeout); @@ -30,6 +34,16 @@ struct SessionData { query::Interpreter interpreter; }; +/** + * Bolt Session Exception + * + * Used to indicate that something went wrong during the session execution. + */ +class SessionException : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + /** * Bolt Session * @@ -40,25 +54,7 @@ struct SessionData { template class Session { public: - // Wrapper around socket that checks if session has timed out on write - // failures, used in encoder buffer. - class TimeoutSocket { - public: - explicit TimeoutSocket(Session &session) : session_(session) {} - - bool Write(const uint8_t *data, size_t len) { - // The have_more flag is hardcoded to false here because the bolt data - // is internally buffered and doesn't need to be buffered by the kernel. - return session_.socket_.Write(data, len, false, - [this] { return !session_.TimedOut(); }); - } - - private: - Session &session_; - }; - - using ResultStreamT = - ResultStream>>; + using ResultStreamT = ResultStream>>; using StreamBuffer = io::network::StreamBuffer; Session(TSocket &&socket, SessionData &data) @@ -67,8 +63,9 @@ class Session { interpreter_(data.interpreter) {} ~Session() { - DCHECK(!db_accessor_) - << "Transaction should have already be closed in Close"; + if (db_accessor_) { + Abort(); + } } /** @@ -158,8 +155,12 @@ class Session { * Returns true if session has timed out. 
Session times out if there was no * activity in FLAGS_sessions_inactivity_timeout seconds or if there is a * active transaction with shoul_abort flag set to true. + * This function must be thread safe because this function and + * `RefreshLastEventTime` are called from different threads in the + * network stack. */ - bool TimedOut() const { + bool TimedOut() { + std::unique_lock guard(lock_); return db_accessor_ ? db_accessor_->should_abort() : last_event_time_ + std::chrono::seconds( @@ -167,17 +168,6 @@ class Session { std::chrono::steady_clock::now(); } - /** - * Closes the session (client socket). - */ - void Close() { - DLOG(INFO) << "Closing session"; - if (db_accessor_) { - Abort(); - } - this->socket_.Close(); - } - /** * Commits associated transaction. */ @@ -198,9 +188,16 @@ class Session { TSocket &socket() { return socket_; } + /** + * Function that is called by the network stack to set the last event time. + * It is used to determine whether the session has timed out. + * This function must be thread safe because this function and + * `TimedOut` are called from different threads in the network stack. + */ void RefreshLastEventTime( const std::chrono::time_point &last_event_time) { + std::unique_lock guard(lock_); last_event_time_ = last_event_time; } @@ -210,9 +207,8 @@ class Session { database::MasterBase &db_; query::Interpreter &interpreter_; - TimeoutSocket timeout_socket_{*this}; - ChunkedEncoderBuffer encoder_buffer_{timeout_socket_}; - Encoder> encoder_{encoder_buffer_}; + ChunkedEncoderBuffer encoder_buffer_{socket_}; + Encoder> encoder_{encoder_buffer_}; ResultStreamT output_stream_{encoder_}; Buffer<> buffer_; @@ -224,9 +220,11 @@ class Session { // GraphDbAccessor of active transaction in the session, can be null if // there is no associated transaction. std::unique_ptr db_accessor_; - // Time of the last event. + + // Time of the last event and associated lock. 
std::chrono::time_point last_event_time_ = std::chrono::steady_clock::now(); + SpinLock lock_; private: void ClientFailureInvalidData() { @@ -235,10 +233,17 @@ class Session { // We don't care about the return status because this is called when we // are about to close the connection to the client. encoder_buffer_.Clear(); - encoder_.MessageFailure({{"code", "Memgraph.InvalidData"}, - {"message", "The client has sent invalid data!"}}); - // Close the connection. - Close(); + encoder_.MessageFailure({{"code", "Memgraph.ExecutionException"}, + {"message", + "Something went wrong while executing the query! " + "Check the server logs for more details."}}); + // Throw an exception to indicate that something went wrong with execution + // of the session to trigger session cleanup and socket close. + if (TimedOut()) { + throw SessionException("The session has timed out!"); + } else { + throw SessionException("The client has sent invalid data!"); + } } }; } // namespace communication::bolt diff --git a/src/communication/bolt/v1/states/handshake.hpp b/src/communication/bolt/v1/states/handshake.hpp index 833525909..6ed75474a 100644 --- a/src/communication/bolt/v1/states/handshake.hpp +++ b/src/communication/bolt/v1/states/handshake.hpp @@ -25,7 +25,7 @@ State StateHandshakeRun(TSession &session) { // make sense to check which version the client prefers this will change in // the future. 
- if (!session.timeout_socket_.Write(kProtocol, sizeof(kProtocol))) { + if (!session.socket_.Write(kProtocol, sizeof(kProtocol))) { DLOG(WARNING) << "Couldn't write handshake response!"; return State::Close; } diff --git a/src/communication/listener.hpp b/src/communication/listener.hpp new file mode 100644 index 000000000..a174b1a76 --- /dev/null +++ b/src/communication/listener.hpp @@ -0,0 +1,246 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "io/network/epoll.hpp" +#include "io/network/socket.hpp" +#include "threading/sync/spinlock.hpp" + +namespace communication { + +/** + * This class listens to events on an epoll object and processes them. + * When a new connection is added a `TSession` object is created to handle the + * connection. When the `TSession` handler raises an exception or an error + * occurs the `TSession` object is deleted and the corresponding socket is + * closed. Also, this class has a background thread that periodically, every + * second, checks all sessions for expiration and shuts them down if they have + * expired. + */ +template +class Listener { + private: + // The maximum number of events handled per execution thread is 1. This is + // because each event represents the start of a network request and it doesn't + // make sense to take more than one event because the processing of an event + // can take a long time. + static const int kMaxEvents = 1; + + public: + Listener(TSessionData &data, bool check_for_timeouts) + : data_(data), alive_(true) { + if (check_for_timeouts) { + thread_ = std::thread([this]() { + while (alive_) { + { + std::unique_lock guard(lock_); + for (auto &session : sessions_) { + if (session->TimedOut()) { + LOG(WARNING) << "Session associated with " + << session->socket().endpoint() << " timed out."; + // Here we shutdown the socket to terminate any leftover + // blocking `Write` calls and to signal an event that the + // session is closed. 
Session cleanup will be done in the event + // process function. + session->socket().Shutdown(); + } + } + } + // TODO (mferencevic): Should this be configurable? + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + }); + } + } + + ~Listener() { + alive_.store(false); + if (thread_.joinable()) thread_.join(); + } + + Listener(const Listener &) = delete; + Listener(Listener &&) = delete; + Listener &operator=(const Listener &) = delete; + Listener &operator=(Listener &&) = delete; + + /** + * This function adds a socket to the listening event pool. + * + * @param connection socket which should be added to the event pool + */ + void AddConnection(io::network::Socket &&connection) { + std::unique_lock guard(lock_); + + // Set connection options. + // The socket is left to be a blocking socket, but when `Read` is called + // then a flag is manually set to enable non-blocking read that is used in + // conjunction with `EPOLLET`. That means that the socket is used in a + // non-blocking fashion for reads and a blocking fashion for writes. + connection.SetKeepAlive(); + connection.SetNoDelay(); + + // Remember fd before moving connection into Session. + int fd = connection.fd(); + + // Create a new Session for the connection. + sessions_.push_back( + std::make_unique(std::move(connection), data_)); + + // Register the connection in Epoll. + // We want to listen to an incoming event which is edge triggered and + // we also want to listen on the hangup event. Epoll is hard to use + // concurrently and that is why we use `EPOLLONESHOT`, for a detailed + // description what are the problems and why this is correct see: + // https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/ + epoll_.Add(fd, EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, + sessions_.back().get()); + } + + /** + * This function polls the event queue and processes incoming data. 
+ * It is thread safe and is intended to be called from multiple threads. It + * waits at most 200 ms for an event before returning. + */ + void WaitAndProcessEvents() { + // This array can't be global because this function can be called from + // multiple threads, therefore, it must be on the stack. + io::network::Epoll::Event events[kMaxEvents]; + + // Waits for events, storing at most kMaxEvents (1) of them in the + // events array. It waits for at most 200 milliseconds and returns 0 + // if the timeout expires before any event arrives. + int n = epoll_.Wait(events, kMaxEvents, 200); + if (n <= 0) return; + + // Process the event. + auto &event = events[0]; + + // We get the currently associated Session pointer and immediately + // dereference it here. It is safe to dereference the pointer because + // this design guarantees that there will never be an event that has + // a stale Session pointer. + TSession &session = *reinterpret_cast(event.data.ptr); + + // Process epoll events. We use epoll in edge-triggered mode so we process + // all events here. Only one of the `if` statements must be executed + // because each of them can call `CloseSession` which destroys the session + // and calling a function on that session after that would cause a + // segfault. + if (event.events & EPOLLIN) { + // Read and process all incoming data. + while (ReadAndProcessSession(session)) + ; + } else if (event.events & EPOLLRDHUP) { + // The client closed the connection. + LOG(INFO) << "Client " << session.socket().endpoint() + << " closed the connection."; + CloseSession(session); + } else if (!(event.events & EPOLLIN) || + event.events & (EPOLLHUP | EPOLLERR)) { + // There was an error on the server side. + LOG(ERROR) << "Error occured in session associated with " + << session.socket().endpoint(); + CloseSession(session); + } else { + // Unhandled epoll event. 
+ LOG(ERROR) << "Unhandled event occured in session associated with " + << session.socket().endpoint() << " events: " << event.events; + CloseSession(session); + } + } + + private: + bool ReadAndProcessSession(TSession &session) { + // Refresh the last event time in the session. + // This function must be implemented thread safe. + session.RefreshLastEventTime(std::chrono::steady_clock::now()); + + // Allocate the buffer to fill the data. + auto buf = session.Allocate(); + // Read from the buffer at most buf.len bytes in a non-blocking fashion. + int len = session.socket().Read(buf.data, buf.len, true); + + // Check for read errors. + if (len == -1) { + // This means read would block or read was interrupted by signal, we + // return `false` to stop reading of data. + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + // Rearm epoll to send events from this socket. + epoll_.Modify(session.socket().fd(), + EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session); + return false; + } + // Some other error occurred, close the session. + CloseSession(session); + return false; + } + + // The client has closed the connection. + if (len == 0) { + LOG(INFO) << "Client " << session.socket().endpoint() + << " closed the connection."; + CloseSession(session); + return false; + } + + // Notify the session that it has new data. + session.Written(len); + + // Execute the session. + try { + session.Execute(); + session.RefreshLastEventTime(std::chrono::steady_clock::now()); + } catch (const std::exception &e) { + // Catch all exceptions. + LOG(ERROR) << "Exception was thrown while processing event in session " + "associated with " + << session.socket().endpoint() + << " with message: " << e.what(); + CloseSession(session); + return false; + } + + return true; + } + + void CloseSession(TSession &session) { + // Deregister the Session's socket from epoll to disable further events. 
For + // a detailed description why this is necessary before destroying (closing) + // the socket, see: + // https://idea.popcount.org/2017-03-20-epoll-is-fundamentally-broken-22/ + epoll_.Delete(session.socket().fd()); + + std::unique_lock guard(lock_); + auto it = + std::find_if(sessions_.begin(), sessions_.end(), + [&](const auto &l) { return l->Id() == session.Id(); }); + + CHECK(it != sessions_.end()) + << "Trying to remove session that is not found in sessions!"; + int i = it - sessions_.begin(); + swap(sessions_[i], sessions_.back()); + + // This will call all destructors on the Session. Consequently, it will call + // the destructor on the Socket and close the socket. + sessions_.pop_back(); + } + + io::network::Epoll epoll_; + + TSessionData &data_; + + SpinLock lock_; + std::vector> sessions_; + + std::thread thread_; + std::atomic alive_; +}; +} // namespace communication diff --git a/src/communication/rpc/protocol.cpp b/src/communication/rpc/protocol.cpp index a652eb6b4..0f4119243 100644 --- a/src/communication/rpc/protocol.cpp +++ b/src/communication/rpc/protocol.cpp @@ -16,8 +16,6 @@ namespace communication::rpc { Session::Session(Socket &&socket, System &system) : socket_(std::make_shared(std::move(socket))), system_(system) {} -bool Session::Alive() const { return alive_; } - void Session::Execute() { if (!handshake_done_) { if (buffer_.size() < sizeof(MessageSize)) return; @@ -59,14 +57,6 @@ StreamBuffer Session::Allocate() { return buffer_.Allocate(); } void Session::Written(size_t len) { buffer_.Written(len); } -void Session::Close() { - DLOG(INFO) << "Closing session"; - // We explicitly close the socket here to remove the socket from the epoll - // event loop. The response message send will fail but that is OK and - // intended because the remote side closed the connection. 
- socket_.get()->Close(); -} - void SendMessage(Socket &socket, uint32_t message_id, std::unique_ptr &message) { CHECK(message) << "Trying to send nullptr instead of message"; diff --git a/src/communication/rpc/protocol.hpp b/src/communication/rpc/protocol.hpp index d90b4bc4b..94625a686 100644 --- a/src/communication/rpc/protocol.hpp +++ b/src/communication/rpc/protocol.hpp @@ -41,11 +41,6 @@ class Session { int Id() const { return socket_->fd(); } - /** - * Returns the protocol alive state - */ - bool Alive() const; - /** * Executes the protocol after data has been read into the buffer. * Goes through the protocol states in order to execute commands from the @@ -71,11 +66,6 @@ class Session { bool TimedOut() { return false; } - /** - * Closes the session (client socket). - */ - void Close(); - Socket &socket() { return *socket_; } void RefreshLastEventTime( @@ -93,7 +83,6 @@ class Session { std::string service_name_; bool handshake_done_{false}; - bool alive_{true}; Buffer buffer_; }; diff --git a/src/communication/rpc/server.cpp b/src/communication/rpc/server.cpp index 4d23e0011..f40eada14 100644 --- a/src/communication/rpc/server.cpp +++ b/src/communication/rpc/server.cpp @@ -13,7 +13,7 @@ namespace communication::rpc { System::System(const io::network::Endpoint &endpoint, const size_t workers_count) - : server_(endpoint, *this, workers_count) {} + : server_(endpoint, *this, false, workers_count) {} System::~System() {} diff --git a/src/communication/server.hpp b/src/communication/server.hpp index 051bf4494..75de7b288 100644 --- a/src/communication/server.hpp +++ b/src/communication/server.hpp @@ -10,24 +10,21 @@ #include #include -#include "communication/worker.hpp" +#include "communication/listener.hpp" #include "io/network/socket.hpp" -#include "io/network/socket_event_dispatcher.hpp" namespace communication { -/** - * TODO (mferencevic): document methods - */ - /** * Communication server. 
- * Listens for incomming connections on the server port and assigns them in a - * round-robin manner to it's workers. Started automatically on constructor, and - * stopped at destructor. + * + * Listens for incoming connections on the server port and assigns them to the + * connection listener. The listener processes the events with a thread pool + * that has `num_workers` threads. It is started automatically on constructor, + * and stopped at destructor. * * Current Server achitecture: - * incomming connection -> server -> worker -> session + * incoming connection -> server -> listener -> session * * @tparam TSession the server can handle different Sessions, each session * represents a different protocol so the same network infrastructure @@ -38,7 +35,6 @@ namespace communication { template class Server { public: - using WorkerT = Worker; using Socket = io::network::Socket; /** @@ -46,43 +42,35 @@ class Server { * invokes workers_count workers */ Server(const io::network::Endpoint &endpoint, TSessionData &session_data, + bool check_for_timeouts, size_t workers_count = std::thread::hardware_concurrency()) - : session_data_(session_data) { + : listener_(session_data, check_for_timeouts) { // Without server we can't continue with application so we can just // terminate here. 
if (!socket_.Bind(endpoint)) { - LOG(FATAL) << "Cannot bind to socket on " << endpoint.address() << " at " - << endpoint.port(); + LOG(FATAL) << "Cannot bind to socket on " << endpoint; } - socket_.SetNonBlocking(); + socket_.SetTimeout(1, 0); if (!socket_.Listen(1024)) { LOG(FATAL) << "Cannot listen on socket!"; } - working_thread_ = std::thread([this, workers_count]() { + + thread_ = std::thread([this, workers_count]() { std::cout << fmt::format("Starting {} workers", workers_count) << std::endl; - workers_.reserve(workers_count); for (size_t i = 0; i < workers_count; ++i) { - workers_.push_back(std::make_unique(session_data_)); - worker_threads_.emplace_back( - [this](WorkerT &worker) -> void { worker.Start(alive_); }, - std::ref(*workers_.back())); + worker_threads_.emplace_back([this]() { + while (alive_) { + listener_.WaitAndProcessEvents(); + } + }); } - std::cout << "Server is fully armed and operational" << std::endl; - std::cout << fmt::format("Listening on {} at {}", - socket_.endpoint().address(), - socket_.endpoint().port()) - << std::endl; - std::vector> acceptors; - acceptors.emplace_back( - std::make_unique(socket_, *this)); - auto &acceptor = *acceptors.back().get(); - io::network::SocketEventDispatcher dispatcher{ - acceptors}; - dispatcher.AddListener(socket_.fd(), acceptor, EPOLLIN); + std::cout << "Server is fully armed and operational" << std::endl; + std::cout << "Listening on " << socket_.endpoint() << std::endl; + while (alive_) { - dispatcher.WaitAndProcessEvents(); + AcceptConnection(); } std::cout << "Shutting down..." 
<< std::endl; @@ -97,6 +85,11 @@ class Server { AwaitShutdown(); } + Server(const Server &) = delete; + Server(Server &&) = delete; + Server &operator=(const Server &) = delete; + Server &operator=(Server &&) = delete; + const auto &endpoint() const { return socket_.endpoint(); } /// Stops server manually @@ -104,73 +97,33 @@ class Server { // This should be as simple as possible, so that it can be called inside a // signal handler. alive_.store(false); + // Shutdown the socket to return from any waiting `Accept` calls. + socket_.Shutdown(); } /// Waits for the server to be signaled to shutdown void AwaitShutdown() { - if (working_thread_.joinable()) working_thread_.join(); + if (thread_.joinable()) thread_.join(); } private: - class ConnectionAcceptor { - public: - ConnectionAcceptor(Socket &socket, Server &server) - : socket_(socket), server_(server) {} - - void OnData() { - DCHECK(server_.idx_ < server_.workers_.size()) << "Invalid worker id."; - DLOG(INFO) << "On connect"; - auto connection = AcceptConnection(); - if (!connection) { - // Connection is not available anymore or configuration failed. - return; - } - server_.workers_[server_.idx_]->AddConnection(std::move(*connection)); - server_.idx_ = (server_.idx_ + 1) % server_.workers_.size(); + void AcceptConnection() { + // Accept a connection from a socket. + auto s = socket_.Accept(); + if (!s) { + // Connection is not available anymore or configuration failed. + return; } + LOG(INFO) << "Accepted a connection from " << s->endpoint(); + listener_.AddConnection(std::move(*s)); + } - void OnClose() { socket_.Close(); } - - void OnException(const std::exception &e) { - LOG(FATAL) << "Exception was thrown while processing event on socket " - << socket_.fd() << " with message: " << e.what(); - } - - void OnError() { LOG(FATAL) << "Error on server side occured in epoll"; } - - private: - // Accepts connection on socket_ and configures new connections. 
If done - // successfuly new socket (connection) is returner, nullopt otherwise. - std::experimental::optional AcceptConnection() { - DLOG(INFO) << "Accept new connection on socket: " << socket_.fd(); - - // Accept a connection from a socket. - auto s = socket_.Accept(); - if (!s) return std::experimental::nullopt; - - DLOG(INFO) << fmt::format( - "Accepted a connection: socket {}, address '{}', family {}, port {}", - s->fd(), s->endpoint().address(), s->endpoint().family(), - s->endpoint().port()); - - s->SetTimeout(1, 0); - s->SetKeepAlive(); - s->SetNoDelay(); - return s; - } - - Socket &socket_; - Server &server_; - }; - - std::vector> workers_; - std::vector worker_threads_; - std::thread working_thread_; std::atomic alive_{true}; - int idx_{0}; + std::thread thread_; + std::vector worker_threads_; Socket socket_; - TSessionData &session_data_; + Listener listener_; }; } // namespace communication diff --git a/src/communication/worker.hpp b/src/communication/worker.hpp deleted file mode 100644 index 92986bdc1..000000000 --- a/src/communication/worker.hpp +++ /dev/null @@ -1,194 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "io/network/network_error.hpp" -#include "io/network/socket.hpp" -#include "io/network/socket_event_dispatcher.hpp" -#include "io/network/stream_buffer.hpp" -#include "threading/sync/spinlock.hpp" - -namespace communication { - -/** - * TODO (mferencevic): document methods - */ - -/** - * Communication worker. - * Listens for incomming data on connections and accepts new connections. - * Also, executes sessions on incomming data. 
- * - * @tparam TSession the worker can handle different Sessions, each session - * represents a different protocol so the same network infrastructure - * can be used for handling different protocols - * @tparam TSessionData the class with objects that will be forwarded to the - * session - */ -template -class Worker { - using Socket = io::network::Socket; - - public: - void AddConnection(Socket &&connection) { - std::unique_lock guard(lock_); - // Remember fd before moving connection into SessionListener. - int fd = connection.fd(); - session_listeners_.push_back( - std::make_unique(std::move(connection), *this)); - // We want to listen to an incoming event which is edge triggered and - // we also want to listen on the hangup event. - dispatcher_.AddListener(fd, *session_listeners_.back(), - EPOLLIN | EPOLLRDHUP); - } - - explicit Worker(TSessionData &session_data) : session_data_(session_data) {} - - void Start(std::atomic &alive) { - while (alive) { - dispatcher_.WaitAndProcessEvents(); - - bool check_sessions_for_timeouts = true; - while (check_sessions_for_timeouts) { - check_sessions_for_timeouts = false; - - std::unique_lock guard(lock_); - for (auto &session_listener : session_listeners_) { - if (session_listener->session().TimedOut()) { - // We need to unlock here, because OnSessionAndTxTimeout will need - // to acquire same lock. - guard.unlock(); - session_listener->OnSessionAndTxTimeout(); - // Since we released lock we can't continue iteration so we need to - // break. There could still be more sessions that timed out so we - // set check_sessions_for_timeout back to true. - check_sessions_for_timeouts = true; - break; - } - } - } - } - } - - private: - // TODO: Think about ownership. Who should own socket session, - // SessionSocketListener or Worker? 
- class SessionSocketListener { - public: - SessionSocketListener(Socket &&socket, - Worker &worker) - : session_(std::move(socket), worker.session_data_), worker_(worker) {} - - auto &session() { return session_; } - const auto &session() const { return session_; } - const auto &TimedOut() const { return session_.TimedOut(); } - - void OnData() { - session_.RefreshLastEventTime(std::chrono::steady_clock::now()); - DLOG(INFO) << "On data"; - // allocate the buffer to fill the data - auto buf = session_.Allocate(); - // read from the buffer at most buf.len bytes - int len = session_.socket().Read(buf.data, buf.len); - - // check for read errors - if (len == -1) { - // This means read would block or read was interrupted by signal. - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - return; - } - // Some other error occurred, check errno. - OnError(); - return; - } - // The client has closed the connection. - if (len == 0) { - DLOG(WARNING) << "Calling OnClose because the socket is closed!"; - OnClose(); - return; - } - - // Notify the stream that it has new data. - session_.Written(len); - DLOG(INFO) << "OnRead"; - try { - session_.Execute(); - } catch (const std::exception &e) { - LOG(ERROR) << "Error occured while executing statement with message: " - << e.what(); - OnError(); - } - session_.RefreshLastEventTime(std::chrono::steady_clock::now()); - } - - // TODO: Remove duplication in next three functions. 
- void OnError() { - LOG(ERROR) << fmt::format( - "Error occured in session associated with {}:{}", - session_.socket().endpoint().address(), - session_.socket().endpoint().port()); - CloseSession(); - } - - void OnException(const std::exception &e) { - LOG(ERROR) << fmt::format( - "Exception was thrown while processing event in session associated " - "with {}:{} with message: {}", - session_.socket().endpoint().address(), - session_.socket().endpoint().port(), e.what()); - CloseSession(); - } - - void OnSessionAndTxTimeout() { - LOG(WARNING) << fmt::format( - "Session or transaction associated with {}:{} timed out.", - session_.socket().endpoint().address(), - session_.socket().endpoint().port()); - // TODO: report to client what happend. - CloseSession(); - } - - void OnClose() { - LOG(INFO) << fmt::format("Client {}:{} closed the connection.", - session_.socket().endpoint().address(), - session_.socket().endpoint().port()); - CloseSession(); - } - - private: - void CloseSession() { - session_.Close(); - - std::unique_lock guard(worker_.lock_); - auto it = std::find_if( - worker_.session_listeners_.begin(), worker_.session_listeners_.end(), - [&](const auto &l) { return l->session_.Id() == session_.Id(); }); - - CHECK(it != worker_.session_listeners_.end()) - << "Trying to remove session that is not found in worker's sessions"; - int i = it - worker_.session_listeners_.begin(); - swap(worker_.session_listeners_[i], worker_.session_listeners_.back()); - worker_.session_listeners_.pop_back(); - } - - TSession session_; - Worker &worker_; - }; - - SpinLock lock_; - TSessionData &session_data_; - std::vector> session_listeners_; - io::network::SocketEventDispatcher dispatcher_{session_listeners_}; -}; -} // namespace communication diff --git a/src/io/network/epoll.hpp b/src/io/network/epoll.hpp index 5e3814741..d82ce7db7 100644 --- a/src/io/network/epoll.hpp +++ b/src/io/network/epoll.hpp @@ -21,31 +21,78 @@ class Epoll { public: using Event = struct epoll_event; - 
explicit Epoll(int flags) : epoll_fd_(epoll_create1(flags)) { + Epoll(bool set_cloexec = false) + : epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) { // epoll_create1 returns an error if there is a logical error in our code // (for example invalid flags) or if there is irrecoverable error. In both // cases it is best to terminate. - CHECK(epoll_fd_ != -1) << "Error on epoll_create1, errno: " << errno - << ", message: " << strerror(errno); + CHECK(epoll_fd_ != -1) << "Error on epoll create: (" << errno << ") " + << strerror(errno); } - void Add(int fd, Event *event) { - auto status = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, event); + /** + * This function adds/modifies a file descriptor to be listened for events. + * + * @param fd file descriptor to add to epoll + * @param events epoll events mask + * @param ptr pointer to the associated event handler + * @param modify modify an existing file descriptor + */ + void Add(int fd, uint32_t events, void *ptr, bool modify = false) { + Event event; + event.events = events; + event.data.ptr = ptr; + int status = epoll_ctl(epoll_fd_, (modify ? EPOLL_CTL_MOD : EPOLL_CTL_ADD), + fd, &event); // epoll_ctl can return an error on our logical error or on irrecoverable // error. There is a third possibility that some system limit is reached. In // that case we could return an erorr and close connection. Chances of // reaching system limit in normally working memgraph is extremely unlikely, // so it is correct to terminate even in that case. - CHECK(!status) << "Error on epoll_ctl, errno: " << errno - << ", message: " << strerror(errno); + CHECK(!status) << "Error on epoll " << (modify ? "modify" : "add") << ": (" + << errno << ") " << strerror(errno); } + /** + * This function modifies a file descriptor that is listened for events. 
+ * + * @param fd file descriptor to modify in epoll + * @param events epoll events mask + * @param ptr pointer to the associated event handler + */ + void Modify(int fd, uint32_t events, void *ptr) { + Add(fd, events, ptr, true); + } + + /** + * This function deletes a file descriptor that is listened for events. + * + * @param fd file descriptor to delete from epoll + */ + void Delete(int fd) { + int status = epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, NULL); + // epoll_ctl can return an error on our logical error or on irrecoverable + // error. There is a third possibility that some system limit is reached. In + // that case we could return an error and close connection. Chances of + // reaching system limit in normally working memgraph is extremely unlikely, + // so it is correct to terminate even in that case. + CHECK(!status) << "Error on epoll delete: (" << errno << ") " + << strerror(errno); + } + + /** + * This function waits for events from epoll. + * It can be called from multiple threads, but should be used with care in + * that case, see: + * https://stackoverflow.com/questions/7058737/is-epoll-thread-safe + * + * @param events array that receives the reported events + * @param max_events maximum number of events to return + * @param timeout timeout passed to epoll_wait, in milliseconds + */ int Wait(Event *events, int max_events, int timeout) { auto num_events = epoll_wait(epoll_fd_, events, max_events, timeout); // If this check fails there was logical error in our code. CHECK(num_events != -1 || errno == EINTR) - << "Error on epoll_wait, errno: " << errno - << ", message: " << strerror(errno); + << "Error on epoll wait: (" << errno << ") " << strerror(errno); // num_events can be -1 if errno was EINTR (epoll_wait interrupted by signal // handler). We treat that as no events, so we return 0. return num_events == -1 ? 
0 : num_events; diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp index 3320d603a..52cca52e9 100644 --- a/src/io/network/socket.cpp +++ b/src/io/network/socket.cpp @@ -193,8 +193,7 @@ std::experimental::optional Socket::Accept() { return Socket(sfd, endpoint); } -bool Socket::Write(const uint8_t *data, size_t len, bool have_more, - const std::function &keep_retrying) { +bool Socket::Write(const uint8_t *data, size_t len, bool have_more) { // MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a // connection dies mid-write, the socket will only return an EPIPE error. int flags = MSG_NOSIGNAL | (have_more ? MSG_MORE : 0); @@ -205,12 +204,8 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more, // Terminal error, return failure. return false; } - // TODO: This can still cause timed out session to continue for a very - // long time. For example if timeout on send is 1 second and after every - // sencond we succeed in writing only one byte that this function can - // block for len seconds. Change semantics of keep_retrying function so - // that this check can be done in while loop even if send succeeds. - if (!keep_retrying()) return false; + // Non-fatal error, retry. + continue; } else { len -= written; data += written; @@ -219,13 +214,12 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more, return true; } -bool Socket::Write(const std::string &s, bool have_more, - const std::function &keep_retrying) { - return Write(reinterpret_cast(s.data()), s.size(), have_more, - keep_retrying); +bool Socket::Write(const std::string &s, bool have_more) { + return Write(reinterpret_cast(s.data()), s.size(), + have_more); } -int Socket::Read(void *buffer, size_t len) { - return read(socket_, buffer, len); +int Socket::Read(void *buffer, size_t len, bool nonblock) { + return recv(socket_, buffer, len, nonblock ? 
MSG_DONTWAIT : 0); } } // namespace io::network diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp index 6b8f8c08d..abc47968a 100644 --- a/src/io/network/socket.hpp +++ b/src/io/network/socket.hpp @@ -132,23 +132,13 @@ class Socket { * @param len length of char* or uint8_t* data * @param have_more set to true if you plan to send more data to allow the * kernel to buffer the data instead of immediately sending it out - * @param keep_retrying while function executes to true socket will retry to - * write data if nonterminal error occurred on socket (EAGAIN, EWOULDBLOCK, - * EINTR)... useful if socket is in nonblocking mode or timeout is set on a - * socket. By default Write doesn't retry if any error occurrs. - * - * TODO: Logic for retrying can be in derived class or in a wrapper of this - * class, unfortunately from current return value we don't know what error - * occured nor how much data was written. * * @return write success status: * true if write succeeded * false if write failed */ - bool Write(const uint8_t *data, size_t len, bool have_more = false, - const std::function &keep_retrying = [] { return false; }); - bool Write(const std::string &s, bool have_more = false, - const std::function &keep_retrying = [] { return false; }); + bool Write(const uint8_t *data, size_t len, bool have_more = false); + bool Write(const std::string &s, bool have_more = false); /** * Read data from the socket. @@ -156,17 +146,14 @@ class Socket { * * @param buffer pointer to the read buffer * @param len length of the read buffer + * @param nonblock set to true if you want a non-blocking read * * @return read success status: * > 0 if data was read, means number of read bytes * == 0 if the client closed the connection * < 0 if an error has occurred */ - // TODO: Return type should be something like StatusOr which would return - // number of read bytes if read succeeded and error code and error message - // otherwise (deduced from errno). 
We can implement that type easily on top of - // std::variant once c++17 becomes available in memgraph. - int Read(void *buffer, size_t len); + int Read(void *buffer, size_t len, bool nonblock = false); private: Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {} diff --git a/src/io/network/socket_event_dispatcher.hpp b/src/io/network/socket_event_dispatcher.hpp deleted file mode 100644 index 2474203ac..000000000 --- a/src/io/network/socket_event_dispatcher.hpp +++ /dev/null @@ -1,95 +0,0 @@ -#pragma once - -#include - -#include "io/network/epoll.hpp" -#include "utils/crtp.hpp" - -namespace io::network { - -/** - * This class listens to events on an epoll object and calls - * callback functions to process them. - */ - -template -class SocketEventDispatcher { - public: - explicit SocketEventDispatcher( - std::vector> &listeners, uint32_t flags = 0) - : epoll_(flags), listeners_(listeners) {} - - void AddListener(int fd, Listener &listener, uint32_t events) { - // Add the listener associated to fd file descriptor to epoll. - epoll_event event; - event.events = events; - event.data.ptr = &listener; - epoll_.Add(fd, &event); - } - - // Returns true if there was event before timeout. - bool WaitAndProcessEvents() { - // Waits for an event/multiple events and returns a maximum of max_events - // and stores them in the events array. It waits for wait_timeout - // milliseconds. If wait_timeout is achieved, returns 0. - const auto n = epoll_.Wait(events_, kMaxEvents, 200); - DLOG_IF(INFO, n > 0) << "number of events: " << n; - - // Go through all events and process them in order. - for (int i = 0; i < n; ++i) { - auto &event = events_[i]; - Listener *listener = reinterpret_cast(event.data.ptr); - - // Check if the listener that we got in the epoll event still exists - // because it might have been deleted in a previous call. 
- auto it = - std::find_if(listeners_.begin(), listeners_.end(), - [&](const auto &l) { return l.get() == listener; }); - // If the listener doesn't exist anymore just ignore the event. - if (it == listeners_.end()) continue; - - // Even though it is possible for multiple events to be reported we handle - // only one of them. Since we use epoll in level triggered mode - // unprocessed events will be reported next time we call epoll_wait. This - // kind of processing events is safer since callbacks can destroy listener - // and calling next callback on listener object will segfault. More subtle - // bugs are also possible: one callback can handle multiple events - // (maybe even in a subtle implicit way) and then we don't want to call - // multiple callbacks since we are not sure if those are valid anymore. - try { - if (event.events & EPOLLIN) { - // We have some data waiting to be read. - listener->OnData(); - continue; - } - - if (event.events & EPOLLRDHUP) { - listener->OnClose(); - continue; - } - - // There was an error on the server side. - if (!(event.events & EPOLLIN) || event.events & (EPOLLHUP | EPOLLERR)) { - listener->OnError(); - continue; - } - - } catch (const std::exception &e) { - listener->OnException(e); - } - } - - return n > 0; - } - - private: - static const int kMaxEvents = 64; - // TODO: epoll is really ridiculous here. We don't plan to handle thousands of - // connections so ppoll would actually be better or (even plain nonblocking - // socket). 
- Epoll epoll_; - Epoll::Event events_[kMaxEvents]; - std::vector> &listeners_; -}; - -} // namespace io::network diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index ec56b576d..eecabe05d 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -107,7 +107,7 @@ void MasterMain() { database::Master db; SessionData session_data{db}; ServerT server({FLAGS_interface, static_cast(FLAGS_port)}, - session_data, FLAGS_num_workers); + session_data, true, FLAGS_num_workers); // Handler for regular termination signals auto shutdown = [&server] { @@ -135,7 +135,7 @@ void SingleNodeMain() { database::SingleNode db; SessionData session_data{db}; ServerT server({FLAGS_interface, static_cast(FLAGS_port)}, - session_data, FLAGS_num_workers); + session_data, true, FLAGS_num_workers); // Handler for regular termination signals auto shutdown = [&server] { diff --git a/tests/concurrent/network_common.hpp b/tests/concurrent/network_common.hpp index bb86834ea..6f02de873 100644 --- a/tests/concurrent/network_common.hpp +++ b/tests/concurrent/network_common.hpp @@ -30,7 +30,6 @@ class TestSession { event_.data.ptr = this; } - bool Alive() const { return socket_.IsOpen(); } bool TimedOut() const { return false; } int Id() const { return socket_.fd(); } @@ -53,11 +52,6 @@ class TestSession { void Written(size_t len) { buffer_.Written(len); } - void Close() { - DLOG(INFO) << "Close session!"; - this->socket_.Close(); - } - Socket &socket() { return socket_; } void RefreshLastEventTime( diff --git a/tests/concurrent/network_read_hang.cpp b/tests/concurrent/network_read_hang.cpp index a1be78a01..392ea885c 100644 --- a/tests/concurrent/network_read_hang.cpp +++ b/tests/concurrent/network_read_hang.cpp @@ -31,8 +31,6 @@ class TestSession { event_.data.ptr = this; } - bool Alive() const { return socket_.IsOpen(); } - bool TimedOut() const { return false; } int Id() const { return socket_.fd(); } @@ -43,8 +41,6 @@ class TestSession { void Written(size_t len) { 
buffer_.Written(len); } - void Close() { this->socket_.Close(); } - Socket &socket() { return socket_; } void RefreshLastEventTime( @@ -86,7 +82,7 @@ TEST(Network, SocketReadHangOnConcurrentConnections) { TestData data; int N = (std::thread::hardware_concurrency() + 1) / 2; int Nc = N * 3; - communication::Server server(endpoint, data, N); + communication::Server server(endpoint, data, false, N); const auto &ep = server.endpoint(); // start clients diff --git a/tests/concurrent/network_server.cpp b/tests/concurrent/network_server.cpp index e130fa405..601f1b757 100644 --- a/tests/concurrent/network_server.cpp +++ b/tests/concurrent/network_server.cpp @@ -21,7 +21,7 @@ TEST(Network, Server) { // initialize server TestData session_data; int N = (std::thread::hardware_concurrency() + 1) / 2; - ServerT server(endpoint, session_data, N); + ServerT server(endpoint, session_data, false, N); const auto &ep = server.endpoint(); // start clients diff --git a/tests/concurrent/network_session_leak.cpp b/tests/concurrent/network_session_leak.cpp index 98284d71d..fa5f2400b 100644 --- a/tests/concurrent/network_session_leak.cpp +++ b/tests/concurrent/network_session_leak.cpp @@ -19,11 +19,10 @@ TEST(Network, SessionLeak) { // initialize listen socket Endpoint endpoint(interface, 0); - std::cout << endpoint << std::endl; // initialize server TestData session_data; - ServerT server(endpoint, session_data, 2); + ServerT server(endpoint, session_data, false, 2); // start clients int N = 50; diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index fde4c8bc4..e11266d33 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -20,9 +20,6 @@ class TestSocket { TestSocket(TestSocket &&) = default; TestSocket &operator=(TestSocket &&) = default; - void Close() { socket_ = -1; } - bool IsOpen() { return socket_ != -1; } - int id() const { return socket_; } bool Write(const uint8_t *data, size_t len, bool have_more = false, diff --git 
a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 75d89c495..6d8ad0942 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -16,6 +16,7 @@ std::vector &output = session.socket().output; using communication::bolt::SessionData; +using communication::bolt::SessionException; using communication::bolt::State; using SessionT = communication::bolt::Session; using ResultStreamT = SessionT::ResultStreamT; @@ -89,7 +90,6 @@ void ExecuteHandshake(SessionT &session, std::vector &output) { session.Execute(); ASSERT_EQ(session.state_, State::Init); - ASSERT_TRUE(session.socket().IsOpen()); PrintOutput(output); CheckOutput(output, handshake_resp, 4); } @@ -109,7 +109,6 @@ void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len, void ExecuteInit(SessionT &session, std::vector &output) { ExecuteCommand(session, init_req, sizeof(init_req)); ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.socket().IsOpen()); PrintOutput(output); CheckOutput(output, init_resp, 7); } @@ -149,10 +148,9 @@ TEST(BoltSession, HandshakeWrongPreamble) { // copy 0x00000001 four times for (int i = 0; i < 4; ++i) memcpy(buff.data + i * 4, handshake_req + 4, 4); session.Written(20); - session.Execute(); + ASSERT_THROW(session.Execute(), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); PrintOutput(output); CheckFailureMessage(output); } @@ -166,14 +164,12 @@ TEST(BoltSession, HandshakeInTwoPackets) { session.Execute(); ASSERT_EQ(session.state_, State::Handshake); - ASSERT_TRUE(session.socket().IsOpen()); memcpy(buff.data + 10, handshake_req + 10, 10); session.Written(10); session.Execute(); ASSERT_EQ(session.state_, State::Init); - ASSERT_TRUE(session.socket().IsOpen()); PrintOutput(output); CheckOutput(output, handshake_resp, 4); } @@ -181,10 +177,11 @@ TEST(BoltSession, HandshakeInTwoPackets) { TEST(BoltSession, HandshakeWriteFail) { INIT_VARS; session.socket().SetWriteSuccess(false); - 
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false); + ASSERT_THROW( + ExecuteCommand(session, handshake_req, sizeof(handshake_req), false), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -196,10 +193,10 @@ TEST(BoltSession, HandshakeOK) { TEST(BoltSession, InitWrongSignature) { INIT_VARS; ExecuteHandshake(session, output); - ExecuteCommand(session, run_req_header, sizeof(run_req_header)); + ASSERT_THROW(ExecuteCommand(session, run_req_header, sizeof(run_req_header)), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -209,10 +206,9 @@ TEST(BoltSession, InitWrongMarker) { // wrong marker, good signature uint8_t data[2] = {0x00, init_req[1]}; - ExecuteCommand(session, data, 2); + ASSERT_THROW(ExecuteCommand(session, data, 2), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -224,10 +220,9 @@ TEST(BoltSession, InitMissingData) { for (int i = 0; i < 3; ++i) { INIT_VARS; ExecuteHandshake(session, output); - ExecuteCommand(session, init_req, len[i]); + ASSERT_THROW(ExecuteCommand(session, init_req, len[i]), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } } @@ -236,10 +231,10 @@ TEST(BoltSession, InitWriteFail) { INIT_VARS; ExecuteHandshake(session, output); session.socket().SetWriteSuccess(false); - ExecuteCommand(session, init_req, sizeof(init_req)); + ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -257,10 +252,9 @@ TEST(BoltSession, ExecuteRunWrongMarker) { // wrong marker, good signature uint8_t data[2] = {0x00, run_req_header[1]}; - 
ExecuteCommand(session, data, sizeof(data)); + ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -273,10 +267,10 @@ TEST(BoltSession, ExecuteRunMissingData) { INIT_VARS; ExecuteHandshake(session, output); ExecuteInit(session, output); - ExecuteCommand(session, run_req_header, len[i]); + ASSERT_THROW(ExecuteCommand(session, run_req_header, len[i]), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } } @@ -291,15 +285,17 @@ TEST(BoltSession, ExecuteRunBasicException) { session.socket().SetWriteSuccess(i == 0); WriteRunRequest(session, "MATCH (omnom"); - session.Execute(); + if (i == 0) { + session.Execute(); + } else { + ASSERT_THROW(session.Execute(), SessionException); + } if (i == 0) { ASSERT_EQ(session.state_, State::ErrorIdle); - ASSERT_TRUE(session.socket().IsOpen()); CheckFailureMessage(output); } else { ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); } } @@ -330,10 +326,9 @@ TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) { // wrong marker, good signature uint8_t data[2] = {0x00, dataset[i][1]}; - ExecuteCommand(session, data, sizeof(data)); + ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } } @@ -347,10 +342,10 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) { ExecuteInit(session, output); session.socket().SetWriteSuccess(i == 0); - ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + ASSERT_THROW(ExecuteCommand(session, pullall_req, sizeof(pullall_req)), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); if (i == 0) { CheckFailureMessage(output); } else { 
@@ -377,16 +372,18 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) { if (j == 1) output.clear(); session.socket().SetWriteSuccess(j == 0); - ExecuteCommand(session, dataset[i], 2); + if (j == 0) { + ExecuteCommand(session, dataset[i], 2); + } else { + ASSERT_THROW(ExecuteCommand(session, dataset[i], 2), SessionException); + } if (j == 0) { ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.socket().IsOpen()); ASSERT_FALSE(session.encoder_buffer_.HasData()); PrintOutput(output); } else { ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); } } @@ -398,10 +395,10 @@ TEST(BoltSession, ExecuteInvalidMessage) { ExecuteHandshake(session, output); ExecuteInit(session, output); - ExecuteCommand(session, init_req, sizeof(init_req)); + ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -419,18 +416,21 @@ TEST(BoltSession, ErrorIgnoreMessage) { output.clear(); session.socket().SetWriteSuccess(i == 0); - ExecuteCommand(session, init_req, sizeof(init_req)); + if (i == 0) { + ExecuteCommand(session, init_req, sizeof(init_req)); + } else { + ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)), + SessionException); + } // assert that all data from the init message was cleaned up ASSERT_EQ(session.decoder_buffer_.Size(), 0); if (i == 0) { ASSERT_EQ(session.state_, State::ErrorIdle); - ASSERT_TRUE(session.socket().IsOpen()); CheckOutput(output, ignored_resp, sizeof(ignored_resp)); } else { ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); } } @@ -455,10 +455,9 @@ TEST(BoltSession, ErrorRunAfterRun) { // New run request. 
WriteRunRequest(session, "MATCH (n) RETURN n"); - session.Execute(); + ASSERT_THROW(session.Execute(), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); } TEST(BoltSession, ErrorCantCleanup) { @@ -473,10 +472,10 @@ TEST(BoltSession, ErrorCantCleanup) { output.clear(); // there is data missing in the request, cleanup should fail - ExecuteCommand(session, init_req, sizeof(init_req) - 10); + ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req) - 10), + SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -493,10 +492,9 @@ TEST(BoltSession, ErrorWrongMarker) { // wrong marker, good signature uint8_t data[2] = {0x00, init_req[1]}; - ExecuteCommand(session, data, sizeof(data)); + ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -518,18 +516,20 @@ TEST(BoltSession, ErrorOK) { output.clear(); session.socket().SetWriteSuccess(j == 0); - ExecuteCommand(session, dataset[i], 2); + if (j == 0) { + ExecuteCommand(session, dataset[i], 2); + } else { + ASSERT_THROW(ExecuteCommand(session, dataset[i], 2), SessionException); + } // assert that all data from the init message was cleaned up ASSERT_EQ(session.decoder_buffer_.Size(), 0); if (j == 0) { ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.socket().IsOpen()); CheckOutput(output, success_resp, sizeof(success_resp)); } else { ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); } } @@ -549,10 +549,9 @@ TEST(BoltSession, ErrorMissingData) { // some marker, missing signature uint8_t data[1] = {0x00}; - ExecuteCommand(session, data, sizeof(data)); + ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException); ASSERT_EQ(session.state_, State::Close); - 
ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } @@ -566,7 +565,6 @@ TEST(BoltSession, MultipleChunksInOneExecute) { ExecuteCommand(session, pullall_req, sizeof(pullall_req)); ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.socket().IsOpen()); PrintOutput(output); // Count chunks in output @@ -598,15 +596,13 @@ TEST(BoltSession, PartialChunk) { session.Execute(); ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.socket().IsOpen()); ASSERT_EQ(output.size(), 0); WriteChunkTail(session); - session.Execute(); + ASSERT_THROW(session.Execute(), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.socket().IsOpen()); ASSERT_GT(output.size(), 0); PrintOutput(output); } @@ -657,8 +653,6 @@ TEST(BoltSession, ExplicitTransactionValidQueries) { ASSERT_EQ(session.state_, State::Idle); ASSERT_FALSE(session.db_accessor_); CheckSuccessMessage(output); - - ASSERT_TRUE(session.socket().IsOpen()); } } @@ -702,25 +696,22 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) { CheckSuccessMessage(output); WriteRunRequest(session, transaction_end.c_str()); - session.Execute(); if (transaction_end == "ROLLBACK") { + session.Execute(); ASSERT_EQ(session.state_, State::Result); ASSERT_FALSE(session.db_accessor_); - ASSERT_TRUE(session.socket().IsOpen()); CheckSuccessMessage(output); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); ASSERT_EQ(session.state_, State::Idle); ASSERT_FALSE(session.db_accessor_); - ASSERT_TRUE(session.socket().IsOpen()); CheckSuccessMessage(output); } else { + ASSERT_THROW(session.Execute(), SessionException); ASSERT_EQ(session.state_, State::Close); - ASSERT_FALSE(session.db_accessor_); - ASSERT_FALSE(session.socket().IsOpen()); CheckFailureMessage(output); } } diff --git a/tests/unit/network_timeouts.cpp b/tests/unit/network_timeouts.cpp index a803d07b1..600fcb702 100644 --- a/tests/unit/network_timeouts.cpp +++ b/tests/unit/network_timeouts.cpp @@ -31,7 
+31,7 @@ class RunningServer { database::SingleNode db_; SessionData session_data_{db_}; Endpoint endpoint_{"127.0.0.1", 0}; - ServerT server_{endpoint_, session_data_, 1}; + ServerT server_{endpoint_, session_data_, true, 1}; }; class TestClient : public ClientT { @@ -48,7 +48,7 @@ class TestClient : public ClientT { TEST(NetworkTimeouts, InactiveSession) { FLAGS_query_execution_time_sec = 60; - FLAGS_session_inactivity_timeout = 1; + FLAGS_session_inactivity_timeout = 2; RunningServer rs; TestClient client(rs.server_.endpoint()); @@ -68,12 +68,12 @@ TEST(NetworkTimeouts, InactiveSession) { client.Execute("RETURN 1", {}); // After sleep, session should have timed out. - std::this_thread::sleep_for(1500ms); + std::this_thread::sleep_for(3500ms); EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException); } TEST(NetworkTimeouts, TimeoutInMultiCommandTransaction) { - FLAGS_query_execution_time_sec = 1; + FLAGS_query_execution_time_sec = 2; FLAGS_session_inactivity_timeout = 60; RunningServer rs; @@ -88,7 +88,7 @@ TEST(NetworkTimeouts, TimeoutInMultiCommandTransaction) { client.Execute("RETURN 1", {}); // Session shouldn't be alive anymore. - std::this_thread::sleep_for(2s); + std::this_thread::sleep_for(4s); EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException); }