diff --git a/.ycm_extra_conf.py b/.ycm_extra_conf.py index 2e9a0b650..bcc9950c6 100644 --- a/.ycm_extra_conf.py +++ b/.ycm_extra_conf.py @@ -31,7 +31,8 @@ BASE_FLAGS = [ '-I./libs/gflags/include', '-I./experimental/distributed/src', '-I./experimental/distributed/libs/cereal/include', - '-I./libs/postgresql/include' + '-I./libs/postgresql/include', + '-I./build/include' ] SOURCE_EXTENSIONS = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index b4038cfcf..dc766dd33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,6 +180,7 @@ target_link_libraries(antlr_opencypher_parser_lib antlr4) # all memgraph src files set(memgraph_src_files ${src_dir}/communication/bolt/v1/decoder/decoded_value.cpp + ${src_dir}/communication/bolt/v1/session.cpp ${src_dir}/data_structures/concurrent/skiplist_gc.cpp ${src_dir}/database/graph_db.cpp ${src_dir}/database/graph_db_accessor.cpp diff --git a/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp b/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp index 08a538879..dc2f19433 100644 --- a/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp +++ b/src/communication/bolt/v1/decoder/chunked_decoder_buffer.hpp @@ -116,7 +116,6 @@ class ChunkedDecoderBuffer { return ChunkState::Partial; } - data_.reserve(data_.size() + chunk_size); std::copy(data + 2, data + chunk_size + 2, std::back_inserter(data_)); buffer_.Shift(chunk_size + 2); diff --git a/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp b/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp index 3159886a0..460f2e05f 100644 --- a/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp +++ b/src/communication/bolt/v1/encoder/chunked_encoder_buffer.hpp @@ -92,7 +92,6 @@ class ChunkedEncoderBuffer { // 3. Copy whole chunk into the buffer. size_ += pos_; - buffer_.reserve(size_); std::copy(chunk_.begin(), chunk_.begin() + pos_, std::back_inserter(buffer_)); diff --git a/src/communication/bolt/v1/session.cpp b/src/communication/bolt/v1/session.cpp new file mode 100644 index 000000000..a5163830d --- /dev/null +++ b/src/communication/bolt/v1/session.cpp @@ -0,0 +1,7 @@ +#include "communication/bolt/v1/session.hpp" + +// Danger: If multiple sessions are associated with one worker nonactive +// sessions could become blocked for FLAGS_session_inactivity_timeout time. +// TODO: We should never associate more sessions with one worker. +DEFINE_int32(session_inactivity_timeout, 1800, + "Time in seconds after which inactive sessions will be closed"); diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index 14bb2e0cb..4564b8876 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -4,24 +4,24 @@ #include "io/network/epoll.hpp" #include "io/network/socket.hpp" +#include "io/network/stream_buffer.hpp" #include "database/dbms.hpp" #include "query/interpreter.hpp" #include "transactions/transaction.hpp" #include "communication/bolt/v1/constants.hpp" +#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp" +#include "communication/bolt/v1/decoder/decoder.hpp" +#include "communication/bolt/v1/encoder/encoder.hpp" +#include "communication/bolt/v1/encoder/result_stream.hpp" #include "communication/bolt/v1/state.hpp" #include "communication/bolt/v1/states/error.hpp" #include "communication/bolt/v1/states/executing.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" -#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp" -#include "communication/bolt/v1/decoder/decoder.hpp" -#include "communication/bolt/v1/encoder/encoder.hpp" -#include "communication/bolt/v1/encoder/result_stream.hpp" - -#include "io/network/stream_buffer.hpp" +DECLARE_int32(session_inactivity_timeout); namespace communication::bolt { @@ -30,11 +30,7 @@ namespace communication::bolt { * * This class is responsible for holding references to Dbms and Interpreter * that are passed through the network server and worker to the session. - * - * @tparam OutputStream type of output stream (could be a bolt output stream or - * a test output stream) */ -template <typename OutputStream> struct SessionData { Dbms dbms; query::Interpreter interpreter; @@ -45,32 +41,40 @@ struct SessionData { * * This class is responsible for handling a single client connection. * - * @tparam Socket type of socket (could be a network socket or test socket) + * @tparam TSocket type of socket (could be a network socket or test socket) */ -template <typename Socket> +template <typename TSocket> class Session { - private: - using OutputStream = ResultStream<Encoder<ChunkedEncoderBuffer<Socket>>>; + public: + // Wrapper around socket that checks if session has timed out on write + // failures, used in encoder buffer. + class TimeoutSocket { + public: + explicit TimeoutSocket(Session &session) : session_(session) {} + + bool Write(const uint8_t *data, size_t len) { + return session_.socket_.Write(data, len, + [this] { return !session_.TimedOut(); }); + } + + private: + Session &session_; + }; + + using ResultStreamT = + ResultStream<Encoder<ChunkedEncoderBuffer<TimeoutSocket>>>; using StreamBuffer = io::network::StreamBuffer; - public: - Session(Socket &&socket, SessionData<OutputStream> &data) + Session(TSocket &&socket, SessionData &data) : socket_(std::move(socket)), dbms_(data.dbms), - interpreter_(data.interpreter) { - event_.data.ptr = this; - } + interpreter_(data.interpreter) {} ~Session() { DCHECK(!db_accessor_) << "Transaction should have already be closed in Close"; } - /** - * @return is the session in a valid state - */ - bool Alive() const { return state_ != State::Close; } - /** * @return the socket id */ @@ -154,6 +158,19 @@ class Session { */ void Written(size_t len) { buffer_.Written(len); } + /** + * Returns true if session has timed out. Session times out if there was no + * activity in FLAGS_sessions_inactivity_timeout seconds or if there is a + * active transaction with shoul_abort flag set to true. + */ + bool TimedOut() const { + return db_accessor_ + ? db_accessor_->should_abort() + : last_event_time_ + std::chrono::seconds( + FLAGS_session_inactivity_timeout) < + std::chrono::steady_clock::now(); + } + /** * Closes the session (client socket). */ @@ -183,37 +200,40 @@ class Session { db_accessor_ = nullptr; } - // TODO: Rethink if there is a way to hide some members. At the momemnt all of - // them are public. - Socket socket_; + // TODO: Rethink if there is a way to hide some members. At the momement all + // of them are public. + TSocket socket_; Dbms &dbms_; query::Interpreter &interpreter_; - ChunkedEncoderBuffer<Socket> encoder_buffer_{socket_}; - Encoder<ChunkedEncoderBuffer<Socket>> encoder_{encoder_buffer_}; - OutputStream output_stream_{encoder_}; + TimeoutSocket timeout_socket_{*this}; + ChunkedEncoderBuffer<TimeoutSocket> encoder_buffer_{timeout_socket_}; + Encoder<ChunkedEncoderBuffer<TimeoutSocket>> encoder_{encoder_buffer_}; + ResultStreamT output_stream_{encoder_}; Buffer<> buffer_; ChunkedDecoderBuffer decoder_buffer_{buffer_}; Decoder<ChunkedDecoderBuffer> decoder_{decoder_buffer_}; - io::network::Epoll::Event event_; bool handshake_done_{false}; State state_{State::Handshake}; - // GraphDbAccessor of active transaction in the session, can be null if there - // is no associated transaction. + // GraphDbAccessor of active transaction in the session, can be null if + // there is no associated transaction. std::unique_ptr<GraphDbAccessor> db_accessor_; + // Time of the last event. + std::chrono::time_point<std::chrono::steady_clock> last_event_time_ = + std::chrono::steady_clock::now(); private: void ClientFailureInvalidData() { - // set the state to Close + // Set the state to Close. state_ = State::Close; - // don't care about the return status because this is always - // called when we are about to close the connection to the client + // We don't care about the return status because this is called when we + // are about to close the connection to the client. encoder_buffer_.Clear(); encoder_.MessageFailure({{"code", "Memgraph.InvalidData"}, {"message", "The client has sent invalid data!"}}); - // close the connection + // Close the connection. Close(); } }; diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index c12b73490..500e531be 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -16,8 +16,8 @@ namespace communication::bolt { * The error state is exited upon receiving an ACK_FAILURE or RESET message. * @param session the session that should be used for the run */ -template <typename Session> -State StateErrorRun(Session &session, State state) { +template <typename TSession> +State StateErrorRun(TSession &session, State state) { Marker marker; Signature signature; if (!session.decoder_.ReadMessageHeader(&signature, &marker)) { @@ -60,16 +60,16 @@ State StateErrorRun(Session &session, State state) { } else { uint8_t value = underlying_cast(marker); - // all bolt client messages have less than 15 parameters - // so if we receive anything than a TinyStruct it's an error + // All bolt client messages have less than 15 parameters so if we receive + // anything than a TinyStruct it's an error. if ((value & 0xF0) != underlying_cast(Marker::TinyStruct)) { DLOG(WARNING) << fmt::format( "Expected TinyStruct marker, but received 0x{:02X}!", value); return State::Close; } - // we need to clean up all parameters from this command - value &= 0x0F; // the length is stored in the lower nibble + // We need to clean up all parameters from this command. + value &= 0x0F; // The length is stored in the lower nibble. DecodedValue dv; for (int i = 0; i < value; ++i) { if (!session.decoder_.ReadValue(&dv)) { @@ -79,13 +79,13 @@ State StateErrorRun(Session &session, State state) { } } - // ignore this message + // Ignore this message. if (!session.encoder_.MessageIgnored()) { DLOG(WARNING) << "Couldn't send ignored message!"; return State::Close; } - // cleanup done, command ignored, stay in error state + // Cleanup done, command ignored, stay in error state. return state; } } diff --git a/src/communication/bolt/v1/states/executing.hpp b/src/communication/bolt/v1/states/executing.hpp index 9417a5819..6d98cedd8 100644 --- a/src/communication/bolt/v1/states/executing.hpp +++ b/src/communication/bolt/v1/states/executing.hpp @@ -15,8 +15,8 @@ namespace communication::bolt { -template <typename Session> -State HandleRun(Session &session, State state, Marker marker) { +template <typename TSession> +State HandleRun(TSession &session, State state, Marker marker) { const std::map<std::string, query::TypedValue> kEmptyFields = { {"fields", std::vector<query::TypedValue>{}}}; diff --git a/src/communication/bolt/v1/states/handshake.hpp b/src/communication/bolt/v1/states/handshake.hpp index 1394e5ea2..833525909 100644 --- a/src/communication/bolt/v1/states/handshake.hpp +++ b/src/communication/bolt/v1/states/handshake.hpp @@ -2,6 +2,7 @@ #include <glog/logging.h> +#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/constants.hpp" #include "communication/bolt/v1/state.hpp" @@ -12,8 +13,8 @@ namespace communication::bolt { * This function runs everything to make a Bolt handshake with the client. * @param session the session that should be used for the run */ -template <typename Session> -State StateHandshakeRun(Session &session) { +template <typename TSession> +State StateHandshakeRun(TSession &session) { auto precmp = memcmp(session.buffer_.data(), kPreamble, sizeof(kPreamble)); if (UNLIKELY(precmp != 0)) { DLOG(WARNING) << "Received a wrong preamble!"; @@ -21,10 +22,10 @@ State StateHandshakeRun(Session &session) { } // TODO so far we only support version 1 of the protocol so it doesn't - // make sense to check which version the client prefers - // this will change in the future + // make sense to check which version the client prefers this will change in + // the future. - if (!session.socket_.Write(kProtocol, sizeof(kProtocol))) { + if (!session.timeout_socket_.Write(kProtocol, sizeof(kProtocol))) { DLOG(WARNING) << "Couldn't write handshake response!"; return State::Close; } diff --git a/src/communication/bolt/v1/states/init.hpp b/src/communication/bolt/v1/states/init.hpp index 721277e13..c3829d093 100644 --- a/src/communication/bolt/v1/states/init.hpp +++ b/src/communication/bolt/v1/states/init.hpp @@ -12,9 +12,9 @@ namespace communication::bolt { /** - * Init state run function + * Init state run function. * This function runs everything to initialize a Bolt session with the client. - * @param session the session that should be used for the run + * @param session the session that should be used for the run. */ template <typename Session> State StateInitRun(Session &session) { diff --git a/src/communication/server.hpp b/src/communication/server.hpp index 73e127467..bc77df4d2 100644 --- a/src/communication/server.hpp +++ b/src/communication/server.hpp @@ -11,6 +11,7 @@ #include <glog/logging.h> #include "communication/worker.hpp" +#include "io/network/socket.hpp" #include "io/network/socket_event_dispatcher.hpp" namespace communication { @@ -27,30 +28,44 @@ namespace communication { * Current Server achitecture: * incomming connection -> server -> worker -> session * - * @tparam Session the server can handle different Sessions, each session + * @tparam TSession the server can handle different Sessions, each session * represents a different protocol so the same network infrastructure * can be used for handling different protocols - * @tparam Socket the input/output socket that should be used - * @tparam SessionData the class with objects that will be forwarded to the + * @tparam TSessionData the class with objects that will be forwarded to the * session */ -// TODO: Remove Socket templatisation. Socket requirements are very specific. -// It needs to be in non blocking mode, etc. -template <typename Session, typename Socket, typename SessionData> +template <typename TSession, typename TSessionData> class Server { public: - using worker_t = Worker<Session, Socket, SessionData>; + using WorkerT = Worker<TSession, TSessionData>; + using Socket = io::network::Socket; - Server(Socket &&socket, SessionData &session_data) - : socket_(std::move(socket)), session_data_(session_data) {} + Server(const io::network::NetworkEndpoint &endpoint, + TSessionData &session_data) + : session_data_(session_data) { + // Without server we can't continue with application so we can just + // terminate here. + if (!socket_.Bind(endpoint)) { + LOG(FATAL) << "Cannot bind to socket on " << endpoint.address() << " at " + << endpoint.port(); + } + if (!socket_.SetNonBlocking()) { + LOG(FATAL) << "Cannot set socket to non blocking!"; + } + if (!socket_.Listen(1024)) { + LOG(FATAL) << "Cannot listen on socket!"; + } + } + + const auto &endpoint() const { return socket_.endpoint(); } void Start(size_t n) { std::cout << fmt::format("Starting {} workers", n) << std::endl; workers_.reserve(n); for (size_t i = 0; i < n; ++i) { - workers_.push_back(std::make_unique<worker_t>(session_data_)); + workers_.push_back(std::make_unique<WorkerT>(session_data_)); worker_threads_.emplace_back( - [this](worker_t &worker) -> void { worker.Start(alive_); }, + [this](WorkerT &worker) -> void { worker.Start(alive_); }, std::ref(*workers_.back())); } std::cout << "Server is fully armed and operational" << std::endl; @@ -81,15 +96,14 @@ class Server { private: class ConnectionAcceptor : public io::network::BaseListener { public: - ConnectionAcceptor(Socket &socket, - Server<Session, Socket, SessionData> &server) + ConnectionAcceptor(Socket &socket, Server<TSession, TSessionData> &server) : io::network::BaseListener(socket), server_(server) {} void OnData() { DCHECK(server_.idx_ < server_.workers_.size()) << "Invalid worker id."; DLOG(INFO) << "On connect"; auto connection = AcceptConnection(); - if (UNLIKELY(!connection)) { + if (!connection) { // Connection is not available anymore or configuration failed. return; } @@ -112,21 +126,22 @@ class Server { s->fd(), s->endpoint().address(), s->endpoint().family(), s->endpoint().port()); + if (!s->SetTimeout(1, 0)) return std::experimental::nullopt; if (!s->SetKeepAlive()) return std::experimental::nullopt; if (!s->SetNoDelay()) return std::experimental::nullopt; return s; } - Server<Session, Socket, SessionData> &server_; + Server<TSession, TSessionData> &server_; }; - std::vector<std::unique_ptr<worker_t>> workers_; + std::vector<std::unique_ptr<WorkerT>> workers_; std::vector<std::thread> worker_threads_; std::atomic<bool> alive_{true}; int idx_{0}; Socket socket_; - SessionData &session_data_; + TSessionData &session_data_; }; } // namespace communication diff --git a/src/communication/worker.hpp b/src/communication/worker.hpp index b05370bc3..7e597fc27 100644 --- a/src/communication/worker.hpp +++ b/src/communication/worker.hpp @@ -2,6 +2,7 @@ #include <algorithm> #include <atomic> +#include <chrono> #include <cstdio> #include <iomanip> #include <memory> @@ -9,9 +10,11 @@ #include <sstream> #include <thread> +#include <gflags/gflags.h> #include <glog/logging.h> #include "io/network/network_error.hpp" +#include "io/network/socket.hpp" #include "io/network/socket_event_dispatcher.hpp" #include "io/network/stream_buffer.hpp" #include "threading/sync/spinlock.hpp" @@ -27,20 +30,19 @@ namespace communication { * Listens for incomming data on connections and accepts new connections. * Also, executes sessions on incomming data. * - * @tparam Session the worker can handle different Sessions, each session + * @tparam TSession the worker can handle different Sessions, each session * represents a different protocol so the same network infrastructure * can be used for handling different protocols - * @tparam Socket the input/output socket that should be used - * @tparam SessionData the class with objects that will be forwarded to the + * @tparam TSessionData the class with objects that will be forwarded to the * session */ -template <typename Session, typename Socket, typename SessionData> +template <typename TSession, typename TSessionData> class Worker { - using StreamBuffer = io::network::StreamBuffer; + using Socket = io::network::Socket; public: void AddConnection(Socket &&connection) { - std::unique_lock<SpinLock> gurad(lock_); + std::unique_lock<SpinLock> guard(lock_); // Remember fd before moving connection into SessionListener. int fd = connection.fd(); session_listeners_.push_back( @@ -51,11 +53,31 @@ class Worker { EPOLLIN | EPOLLRDHUP); } - Worker(SessionData &session_data) : session_data_(session_data) {} + Worker(TSessionData &session_data) : session_data_(session_data) {} void Start(std::atomic<bool> &alive) { while (alive) { dispatcher_.WaitAndProcessEvents(); + + bool check_sessions_for_timeouts = true; + while (check_sessions_for_timeouts) { + check_sessions_for_timeouts = false; + + std::unique_lock<SpinLock> guard(lock_); + for (auto &session_listener : session_listeners_) { + if (session_listener->session().TimedOut()) { + // We need to unlock here, because OnSessionAndTxTimeout will need + // to acquire same lock. + guard.unlock(); + session_listener->OnSessionAndTxTimeout(); + // Since we released lock we can't continue iteration so we need to + // break. There could still be more sessions that timed out so we + // set check_sessions_for_timeout back to true. + check_sessions_for_timeouts = true; + break; + } + } + } } } @@ -65,76 +87,96 @@ class Worker { class SessionSocketListener : public io::network::BaseListener { public: SessionSocketListener(Socket &&socket, - Worker<Session, Socket, SessionData> &worker) + Worker<TSession, TSessionData> &worker) : BaseListener(session_.socket_), session_(std::move(socket), worker.session_data_), worker_(worker) {} - void OnError() { - LOG(ERROR) << "Error occured in this session"; - OnClose(); - } + auto &session() { return session_; } + const auto &session() const { return session_; } + const auto &TimedOut() const { return session_.TimedOut(); } void OnData() { + session_.last_event_time_ = std::chrono::steady_clock::now(); DLOG(INFO) << "On data"; - - if (UNLIKELY(!session_.Alive())) { - DLOG(WARNING) << "Calling OnClose because the stream isn't alive!"; - OnClose(); - return; - } - // allocate the buffer to fill the data auto buf = session_.Allocate(); - // read from the buffer at most buf.len bytes int len = session_.socket_.Read(buf.data, buf.len); // check for read errors if (len == -1) { - // this means we have read all available data - if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) { + // This means read would block or read was interrupted by signal. + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { return; } - - // some other error occurred, check errno + // Some other error occurred, check errno. OnError(); return; } - - // end of file, the client has closed the connection - if (UNLIKELY(len == 0)) { + // The client has closed the connection. + if (len == 0) { DLOG(WARNING) << "Calling OnClose because the socket is closed!"; OnClose(); return; } - // notify the stream that it has new data + // Notify the stream that it has new data. session_.Written(len); - DLOG(INFO) << "OnRead"; - try { session_.Execute(); } catch (const std::exception &e) { - LOG(ERROR) << "Error occured while executing statement. " << std::endl + LOG(ERROR) << "Error occured while executing statement with message: " << e.what(); - // TODO: report to client + OnError(); } - // TODO: Should we even continue with this session if error occurs while - // reading. + session_.last_event_time_ = std::chrono::steady_clock::now(); + } + + // TODO: Remove duplication in next three functions. + void OnError() { + LOG(ERROR) << fmt::format( + "Error occured in session associated with {}:{}", + session_.socket_.endpoint().address(), + session_.socket_.endpoint().port()); + CloseSession(); + } + + void OnException(const std::exception &e) { + LOG(ERROR) << fmt::format( + "Exception was thrown while processing event in session associated " + "with {}:{} with message: {}", + session_.socket_.endpoint().address(), + session_.socket_.endpoint().port(), e.what()); + CloseSession(); + } + + void OnSessionAndTxTimeout() { + LOG(WARNING) << fmt::format( + "Session or transaction associated with {}:{} timed out.", + session_.socket_.endpoint().address(), + session_.socket_.endpoint().port()); + // TODO: report to client what happend. + CloseSession(); } void OnClose() { LOG(INFO) << fmt::format("Client {}:{} closed the connection.", session_.socket_.endpoint().address(), - session_.socket_.endpoint().port()) - << std::endl; + session_.socket_.endpoint().port()); + CloseSession(); + } + + private: + void CloseSession() { session_.Close(); - std::unique_lock<SpinLock> gurad(worker_.lock_); + + std::unique_lock<SpinLock> guard(worker_.lock_); auto it = std::find_if( worker_.session_listeners_.begin(), worker_.session_listeners_.end(), [&](const auto &l) { return l->session_.Id() == session_.Id(); }); + CHECK(it != worker_.session_listeners_.end()) << "Trying to remove session that is not found in worker's sessions"; int i = it - worker_.session_listeners_.begin(); @@ -142,13 +184,12 @@ class Worker { worker_.session_listeners_.pop_back(); } - private: - Session session_; + TSession session_; Worker &worker_; }; SpinLock lock_; - SessionData &session_data_; + TSessionData &session_data_; io::network::SocketEventDispatcher<SessionSocketListener> dispatcher_; std::vector<std::unique_ptr<SessionSocketListener>> session_listeners_; }; diff --git a/src/io/network/epoll.hpp b/src/io/network/epoll.hpp index ecb30cf00..66c6ddbf2 100644 --- a/src/io/network/epoll.hpp +++ b/src/io/network/epoll.hpp @@ -21,8 +21,7 @@ class Epoll { public: using Event = struct epoll_event; - Epoll(int flags) { - epoll_fd_ = epoll_create1(flags); + Epoll(int flags) : epoll_fd_(epoll_create1(flags)) { // epoll_create1 returns an error if there is a logical error in our code // (for example invalid flags) or if there is irrecoverable error. In both // cases it is best to terminate. @@ -53,6 +52,6 @@ class Epoll { } private: - int epoll_fd_; + const int epoll_fd_; }; } diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp index 90d4bbd5b..d15e80014 100644 --- a/src/io/network/socket.cpp +++ b/src/io/network/socket.cpp @@ -17,6 +17,7 @@ #include <unistd.h> #include "io/network/addrinfo.hpp" +#include "threading/sync/cpu_relax.hpp" #include "utils/likely.hpp" namespace io::network { @@ -47,7 +48,7 @@ void Socket::Close() { socket_ = -1; } -bool Socket::IsOpen() { return socket_ != -1; } +bool Socket::IsOpen() const { return socket_ != -1; } bool Socket::Connect(const NetworkEndpoint &endpoint) { if (socket_ != -1) return false; @@ -194,25 +195,27 @@ std::experimental::optional<Socket> Socket::Accept() { return Socket(sfd, endpoint); } -const NetworkEndpoint &Socket::endpoint() const { return endpoint_; } - -bool Socket::Write(const std::string &str) { - return Write(str.c_str(), str.size()); -} - -bool Socket::Write(const char *data, size_t len) { - return Write(reinterpret_cast<const uint8_t *>(data), len); -} - -bool Socket::Write(const uint8_t *data, size_t len) { +bool Socket::Write(const uint8_t *data, size_t len, + const std::function<bool()> &keep_retrying) { while (len > 0) { - // MSG_NOSIGNAL is here to disable raising a SIGPIPE - // signal when a connection dies mid-write, the socket - // will only return an EPIPE error + // MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a + // connection dies mid-write, the socket will only return an EPIPE error. auto written = send(socket_, data, len, MSG_NOSIGNAL); - if (UNLIKELY(written == -1)) return false; - len -= written; - data += written; + if (written == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + // Terminal error, return failure. + return false; + } + // TODO: This can still cause timed out session to continue for a very + // long time. For example if timeout on send is 1 second and after every + // sencond we succeed in writing only one byte that this function can + // block for len seconds. Change semantics of keep_retrying function so + // that this check can be done in while loop even if send succeeds. + if (!keep_retrying()) return false; + } else { + len -= written; + data += written; + } } return true; } diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp index a1219c336..8a19828d2 100644 --- a/src/io/network/socket.hpp +++ b/src/io/network/socket.hpp @@ -1,6 +1,7 @@ #pragma once #include <experimental/optional> +#include <functional> #include <iostream> #include "io/network/network_endpoint.hpp" @@ -34,7 +35,7 @@ class Socket { * true if the socket is open * false if the socket is closed */ - bool IsOpen(); + bool IsOpen() const; /** * Connects the socket to the specified endpoint. @@ -127,23 +128,29 @@ class Socket { /** * Returns the currently active endpoint of the socket. */ - const NetworkEndpoint &endpoint() const; + const NetworkEndpoint &endpoint() const { return endpoint_; } /** * Write data to the socket. * Theese functions guarantee that all data will be written. * - * @param str std::string to write to the socket - * @param data char* or uint8_t* to data that should be written + * @param data uint8_t* to data that should be written * @param len length of char* or uint8_t* data + * @param keep_retrying while function executes to true socket will retry to + * write data if nonterminal error occurred on socket (EAGAIN, EWOULDBLOCK, + * EINTR)... useful if socket is in nonblocking mode or timeout is set on a + * socket. By default Write doesn't retry if any error occurrs. + * + * TODO: Logic for retrying can be in derived class or in a wrapper of this + * class, unfortunately from current return value we don't know what error + * occured nor how much data was written. * * @return write success status: * true if write succeeded * false if write failed */ - bool Write(const std::string &str); - bool Write(const char *data, size_t len); - bool Write(const uint8_t *data, size_t len); + bool Write(const uint8_t *data, size_t len, + const std::function<bool()> &keep_retrying = [] { return false; }); /** * Read data from the socket. @@ -157,6 +164,10 @@ class Socket { * == 0 if the client closed the connection * < 0 if an error has occurred */ + // TODO: Return type should be something like StatusOr<int> which would return + // number of read bytes if read succeeded and error code and error message + // otherwise (deduced from errno). We can implement that type easily on top of + // std::variant once c++17 becomes available in memgraph. int Read(void *buffer, size_t len); private: diff --git a/src/io/network/socket_event_dispatcher.hpp b/src/io/network/socket_event_dispatcher.hpp index 85f3d65cf..09627375a 100644 --- a/src/io/network/socket_event_dispatcher.hpp +++ b/src/io/network/socket_event_dispatcher.hpp @@ -43,14 +43,13 @@ class SocketEventDispatcher { // probably what we want to do. try { // Hangup event. - if (UNLIKELY(event.events & EPOLLRDHUP)) { + if (event.events & EPOLLRDHUP) { listener.OnClose(); continue; } // There was an error on the server side. - if (UNLIKELY(!(event.events & EPOLLIN) || - event.events & (EPOLLHUP | EPOLLERR))) { + if (!(event.events & EPOLLIN) || event.events & (EPOLLHUP | EPOLLERR)) { listener.OnError(); continue; } @@ -88,16 +87,11 @@ class BaseListener { void OnData() {} void OnException(const std::exception &e) { - // TODO: this actually sounds quite bad, maybe we should close socket here - // because we don'y know in which state Listener class is. - LOG(ERROR) << "Exception was thrown while processing event on socket " + LOG(FATAL) << "Exception was thrown while processing event on socket " << socket_.fd() << " with message: " << e.what(); } - void OnError() { - LOG(ERROR) << "Error on server side occured in epoll"; - socket_.Close(); - } + void OnError() { LOG(FATAL) << "Error on server side occured in epoll"; } protected: Socket &socket_; diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index d2065a637..3e2755056 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -21,15 +21,12 @@ #include "version.hpp" namespace fs = std::experimental::filesystem; -using endpoint_t = io::network::NetworkEndpoint; -using socket_t = io::network::Socket; -using session_t = communication::bolt::Session<socket_t>; -using result_stream_t = - communication::bolt::ResultStream<communication::bolt::Encoder< - communication::bolt::ChunkedEncoderBuffer<socket_t>>>; -using session_data_t = communication::bolt::SessionData<result_stream_t>; -using bolt_server_t = - communication::Server<session_t, socket_t, session_data_t>; +using io::network::NetworkEndpoint; +using io::network::Socket; +using communication::bolt::SessionData; +using SessionT = communication::bolt::Session<Socket>; +using ResultStreamT = SessionT::ResultStreamT; +using ServerT = communication::Server<SessionT, SessionData>; DEFINE_string(interface, "0.0.0.0", "Communication interface on which to listen."); @@ -37,8 +34,7 @@ DEFINE_string(port, "7687", "Communication port on which to listen."); DEFINE_VALIDATED_int32(num_workers, std::max(std::thread::hardware_concurrency(), 1U), "Number of workers", FLAG_IN_RANGE(1, INT32_MAX)); -DEFINE_string(log_file, "", - "Path to where the log should be stored."); +DEFINE_string(log_file, "", "Path to where the log should be stored."); DEFINE_string(log_link_basename, "", "Basename used for symlink creation to the last log file."); DEFINE_uint64(memory_warning_threshold, 1024, @@ -52,7 +48,7 @@ DEFINE_uint64(memory_warning_threshold, 1024, // 3) env - MEMGRAPH_CONFIG // 4) command line flags -void load_config(int &argc, char **&argv) { +void LoadConfig(int &argc, char **&argv) { std::vector<fs::path> configs = {fs::path("/etc/memgraph/memgraph.conf")}; if (getenv("HOME") != nullptr) configs.emplace_back(fs::path(getenv("HOME")) / @@ -99,7 +95,7 @@ int main(int argc, char **argv) { google::SetUsageMessage("Memgraph database server"); gflags::SetVersionString(version_string); - load_config(argc, argv); + LoadConfig(argc, argv); google::InitGoogleLogging(argv[0]); google::SetLogDestination(google::INFO, FLAGS_log_file.c_str()); @@ -123,31 +119,19 @@ int main(int argc, char **argv) { }); // Initialize bolt session data (Dbms and Interpreter). - session_data_t session_data; + SessionData session_data; // Initialize endpoint. - endpoint_t endpoint; - try { - endpoint = endpoint_t(FLAGS_interface, FLAGS_port); - } catch (io::network::NetworkEndpointException &e) { - LOG(FATAL) << e.what(); - } - - // Initialize socket. - socket_t socket; - if (!socket.Bind(endpoint)) { - LOG(FATAL) << "Cannot bind to socket on " << FLAGS_interface << " at " - << FLAGS_port; - } - if (!socket.SetNonBlocking()) { - LOG(FATAL) << "Cannot set socket to non blocking!"; - } - if (!socket.Listen(1024)) { - LOG(FATAL) << "Cannot listen on socket!"; - } + NetworkEndpoint endpoint = [&] { + try { + return NetworkEndpoint(FLAGS_interface, FLAGS_port); + } catch (io::network::NetworkEndpointException &e) { + LOG(FATAL) << e.what(); + } + }(); // Initialize server. - bolt_server_t server(std::move(socket), session_data); + ServerT server(endpoint, session_data); // register SIGTERM handler SignalHandler::register_handler(Signal::Terminate, diff --git a/tests/concurrent/network_common.hpp b/tests/concurrent/network_common.hpp index decde19b0..166067374 100644 --- a/tests/concurrent/network_common.hpp +++ b/tests/concurrent/network_common.hpp @@ -1,6 +1,7 @@ #pragma once #include <array> +#include <chrono> #include <cstring> #include <iostream> #include <random> @@ -11,24 +12,26 @@ #include "communication/bolt/v1/decoder/buffer.hpp" #include "communication/server.hpp" +#include "database/graph_db_accessor.hpp" #include "io/network/epoll.hpp" #include "io/network/socket.hpp" static constexpr const int SIZE = 60000; static constexpr const int REPLY = 10; -using endpoint_t = io::network::NetworkEndpoint; -using socket_t = io::network::Socket; +using io::network::NetworkEndpoint; +using io::network::Socket; class TestData {}; class TestSession { public: - TestSession(socket_t &&socket, TestData &) : socket_(std::move(socket)) { + TestSession(Socket &&socket, TestData &) : socket_(std::move(socket)) { event_.data.ptr = this; } - bool Alive() { return socket_.IsOpen(); } + bool Alive() const { return socket_.IsOpen(); } + bool TimedOut() const { return false; } int Id() const { return socket_.fd(); } @@ -56,15 +59,12 @@ class TestSession { } communication::bolt::Buffer<SIZE * 2> buffer_; - socket_t socket_; + Socket socket_; io::network::Epoll::Event event_; + std::chrono::time_point<std::chrono::steady_clock> last_event_time_; }; -using test_server_t = communication::Server<TestSession, socket_t, TestData>; - -void server_start(void *serverptr, int num) { - ((test_server_t *)serverptr)->Start(num); -} +using ServerT = communication::Server<TestSession, TestData>; void client_run(int num, const char *interface, const char *port, const unsigned char *data, int lo, int hi) { @@ -72,8 +72,8 @@ void client_run(int num, const char *interface, const char *port, name << "Client " << num; unsigned char buffer[SIZE * REPLY], head[2]; int have, read; - endpoint_t endpoint(interface, port); - socket_t socket; + NetworkEndpoint endpoint(interface, port); + Socket socket; ASSERT_TRUE(socket.Connect(endpoint)); ASSERT_TRUE(socket.SetTimeout(2, 0)); DLOG(INFO) << "Socket create: " << socket.fd(); diff --git a/tests/concurrent/network_read_hang.cpp b/tests/concurrent/network_read_hang.cpp index 9e5c37417..a13e2a9d8 100644 --- a/tests/concurrent/network_read_hang.cpp +++ b/tests/concurrent/network_read_hang.cpp @@ -14,23 +14,26 @@ #include "communication/bolt/v1/decoder/buffer.hpp" #include "communication/server.hpp" +#include "database/graph_db_accessor.hpp" #include "io/network/epoll.hpp" #include "io/network/socket.hpp" static constexpr const char interface[] = "127.0.0.1"; -using endpoint_t = io::network::NetworkEndpoint; -using socket_t = io::network::Socket; +using io::network::NetworkEndpoint; +using io::network::Socket; class TestData {}; class TestSession { public: - TestSession(socket_t &&socket, TestData &) : socket_(std::move(socket)) { + TestSession(Socket &&socket, TestData &) : socket_(std::move(socket)) { event_.data.ptr = this; } - bool Alive() { return socket_.IsOpen(); } + bool Alive() const { return socket_.IsOpen(); } + + bool TimedOut() const { return false; } int Id() const { return socket_.fd(); } @@ -42,19 +45,17 @@ class TestSession { void Close() { this->socket_.Close(); } - socket_t socket_; + Socket socket_; communication::bolt::Buffer<> buffer_; io::network::Epoll::Event event_; + std::chrono::time_point<std::chrono::steady_clock> last_event_time_; }; -using test_server_t = communication::Server<TestSession, socket_t, TestData>; - -test_server_t *serverptr; std::atomic<bool> run{true}; void client_run(int num, const char *interface, const char *port) { - endpoint_t endpoint(interface, port); - socket_t socket; + NetworkEndpoint endpoint(interface, port); + Socket socket; uint8_t data = 0x00; ASSERT_TRUE(socket.Connect(endpoint)); ASSERT_TRUE(socket.SetTimeout(1, 0)); @@ -67,32 +68,22 @@ void client_run(int num, const char *interface, const char *port) { socket.Close(); } -void server_run(void *serverptr, int num) { - ((test_server_t *)serverptr)->Start(num); -} - TEST(Network, SocketReadHangOnConcurrentConnections) { // initialize listen socket - endpoint_t endpoint(interface, "0"); - socket_t socket; - ASSERT_TRUE(socket.Bind(endpoint)); - ASSERT_TRUE(socket.SetNonBlocking()); - ASSERT_TRUE(socket.Listen(1024)); + NetworkEndpoint endpoint(interface, "0"); - // get bound address - auto ep = socket.endpoint(); - printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port()); + printf("ADDRESS: %s, PORT: %d\n", endpoint.address(), endpoint.port()); // initialize server TestData data; - test_server_t server(std::move(socket), data); - serverptr = &server; + communication::Server<TestSession, TestData> server(endpoint, data); // start server int N = (std::thread::hardware_concurrency() + 1) / 2; int Nc = N * 3; - std::thread server_thread(server_run, serverptr, N); + std::thread server_thread([&] { server.Start(N); }); + const auto &ep = server.endpoint(); // start clients std::vector<std::thread> clients; for (int i = 0; i < Nc; ++i) diff --git a/tests/concurrent/network_server.cpp b/tests/concurrent/network_server.cpp index 56c27f52b..342278193 100644 --- a/tests/concurrent/network_server.cpp +++ b/tests/concurrent/network_server.cpp @@ -8,32 +8,23 @@ static constexpr const char interface[] = "127.0.0.1"; unsigned char data[SIZE]; -test_server_t *serverptr; - TEST(Network, Server) { // initialize test data initialize_data(data, SIZE); // initialize listen socket - endpoint_t endpoint(interface, "0"); - socket_t socket; - ASSERT_TRUE(socket.Bind(endpoint)); - ASSERT_TRUE(socket.SetNonBlocking()); - ASSERT_TRUE(socket.Listen(1024)); - - // get bound address - auto ep = socket.endpoint(); - printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port()); + NetworkEndpoint endpoint(interface, "0"); + printf("ADDRESS: %s, PORT: %d\n", endpoint.address(), endpoint.port()); // initialize server TestData session_data; - test_server_t server(std::move(socket), session_data); - serverptr = &server; + ServerT server(endpoint, session_data); // start server int N = (std::thread::hardware_concurrency() + 1) / 2; - std::thread server_thread(server_start, serverptr, N); + std::thread server_thread([&] { server.Start(N); }); + const auto &ep = server.endpoint(); // start clients std::vector<std::thread> clients; for (int i = 0; i < N; ++i) diff --git a/tests/concurrent/network_session_leak.cpp b/tests/concurrent/network_session_leak.cpp index afa5c0bc7..9872ff804 100644 --- a/tests/concurrent/network_session_leak.cpp +++ b/tests/concurrent/network_session_leak.cpp @@ -10,8 +10,6 @@ static constexpr const char interface[] = "127.0.0.1"; unsigned char data[SIZE]; -test_server_t *serverptr; - using namespace std::chrono_literals; TEST(Network, SessionLeak) { @@ -19,28 +17,21 @@ TEST(Network, SessionLeak) { initialize_data(data, SIZE); // initialize listen socket - endpoint_t endpoint(interface, "0"); - socket_t socket; - ASSERT_TRUE(socket.Bind(endpoint)); - ASSERT_TRUE(socket.SetNonBlocking()); - ASSERT_TRUE(socket.Listen(1024)); - - // get bound address - auto ep = socket.endpoint(); - printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port()); + NetworkEndpoint endpoint(interface, "0"); + printf("ADDRESS: %s, PORT: %d\n", endpoint.address(), endpoint.port()); // initialize server TestData session_data; - test_server_t server(std::move(socket), session_data); - serverptr = &server; + ServerT server(endpoint, session_data); // start server - std::thread server_thread(server_start, serverptr, 2); + std::thread server_thread([&] { server.Start(2); }); // start clients int N = 50; std::vector<std::thread> clients; + const auto &ep = server.endpoint(); int testlen = 3000; for (int i = 0; i < N; ++i) { clients.push_back(std::thread(client_run, i, interface, ep.port_str(), data, diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index d5da81910..ed789a9b7 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -1,5 +1,6 @@ #include <array> #include <cstring> +#include <functional> #include <iostream> #include <random> #include <vector> @@ -13,26 +14,19 @@ */ class TestSocket { public: - TestSocket(int socket) : socket(socket) {} - TestSocket(const TestSocket &s) : socket(s.id()){}; - TestSocket(TestSocket &&other) { *this = std::forward<TestSocket>(other); } + explicit TestSocket(int socket) : socket_(socket) {} + TestSocket(const TestSocket &) = default; + TestSocket &operator=(const TestSocket &) = default; + TestSocket(TestSocket &&) = default; + TestSocket &operator=(TestSocket &&) = default; - TestSocket &operator=(TestSocket &&other) { - this->socket = other.socket; - other.socket = -1; - return *this; - } + void Close() { socket_ = -1; } + bool IsOpen() { return socket_ != -1; } - void Close() { socket = -1; } - bool IsOpen() { return socket != -1; } + int id() const { return socket_; } - int id() const { return socket; } - - bool Write(const std::string &str) { return Write(str.c_str(), str.size()); } - bool Write(const char *data, size_t len) { - return Write(reinterpret_cast<const uint8_t *>(data), len); - } - bool Write(const uint8_t *data, size_t len) { + bool Write(const uint8_t *data, size_t len, + const std::function<bool()> & = [] { return false; }) { if (!write_success_) return false; for (size_t i = 0; i < len; ++i) output.push_back(data[i]); return true; @@ -43,7 +37,7 @@ class TestSocket { std::vector<uint8_t> output; protected: - int socket; + int socket_; bool write_success_{true}; }; diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index e898005c2..ec4b4c18f 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -7,18 +7,16 @@ // TODO: This could be done in fixture. // Shortcuts for writing variable initializations in tests -#define INIT_VARS \ - TestSocket socket(10); \ - SessionDataT session_data; \ - SessionT session(std::move(socket), session_data); \ +#define INIT_VARS \ + TestSocket socket(10); \ + SessionData session_data; \ + SessionT session(std::move(socket), session_data); \ std::vector<uint8_t> &output = session.socket_.output; -using ResultStreamT = - communication::bolt::ResultStream<communication::bolt::Encoder< - communication::bolt::ChunkedEncoderBuffer<TestSocket>>>; -using SessionDataT = communication::bolt::SessionData<ResultStreamT>; +using communication::bolt::State; +using communication::bolt::SessionData; using SessionT = communication::bolt::Session<TestSocket>; -using StateT = communication::bolt::State; +using ResultStreamT = SessionT::ResultStreamT; // Sample testdata that has correct inputs and outputs. const uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, @@ -88,7 +86,7 @@ void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) { session.Written(20); session.Execute(); - ASSERT_EQ(session.state_, StateT::Init); + ASSERT_EQ(session.state_, State::Init); ASSERT_TRUE(session.socket_.IsOpen()); PrintOutput(output); CheckOutput(output, handshake_resp, 4); @@ -108,7 +106,7 @@ void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len, // Execute and check a correct init void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) { ExecuteCommand(session, init_req, sizeof(init_req)); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.socket_.IsOpen()); PrintOutput(output); CheckOutput(output, init_resp, 7); @@ -151,7 +149,7 @@ TEST(BoltSession, HandshakeWrongPreamble) { session.Written(20); session.Execute(); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); PrintOutput(output); CheckFailureMessage(output); @@ -165,14 +163,14 @@ TEST(BoltSession, HandshakeInTwoPackets) { session.Written(10); session.Execute(); - ASSERT_EQ(session.state_, StateT::Handshake); + ASSERT_EQ(session.state_, State::Handshake); ASSERT_TRUE(session.socket_.IsOpen()); memcpy(buff.data + 10, handshake_req + 10, 10); session.Written(10); session.Execute(); - ASSERT_EQ(session.state_, StateT::Init); + ASSERT_EQ(session.state_, State::Init); ASSERT_TRUE(session.socket_.IsOpen()); PrintOutput(output); CheckOutput(output, handshake_resp, 4); @@ -183,7 +181,7 @@ TEST(BoltSession, HandshakeWriteFail) { session.socket_.SetWriteSuccess(false); ExecuteCommand(session, handshake_req, sizeof(handshake_req), false); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -198,7 +196,7 @@ TEST(BoltSession, InitWrongSignature) { ExecuteHandshake(session, output); ExecuteCommand(session, run_req_header, sizeof(run_req_header)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -211,7 +209,7 @@ TEST(BoltSession, InitWrongMarker) { uint8_t data[2] = {0x00, init_req[1]}; ExecuteCommand(session, data, 2); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -226,7 +224,7 @@ TEST(BoltSession, InitMissingData) { ExecuteHandshake(session, output); ExecuteCommand(session, init_req, len[i]); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -238,7 +236,7 @@ TEST(BoltSession, InitWriteFail) { session.socket_.SetWriteSuccess(false); ExecuteCommand(session, init_req, sizeof(init_req)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -259,7 +257,7 @@ TEST(BoltSession, ExecuteRunWrongMarker) { uint8_t data[2] = {0x00, run_req_header[1]}; ExecuteCommand(session, data, sizeof(data)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -275,7 +273,7 @@ TEST(BoltSession, ExecuteRunMissingData) { ExecuteInit(session, output); ExecuteCommand(session, run_req_header, len[i]); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -294,11 +292,11 @@ TEST(BoltSession, ExecuteRunBasicException) { session.Execute(); if (i == 0) { - ASSERT_EQ(session.state_, StateT::ErrorIdle); + ASSERT_EQ(session.state_, State::ErrorIdle); ASSERT_TRUE(session.socket_.IsOpen()); CheckFailureMessage(output); } else { - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -314,7 +312,7 @@ TEST(BoltSession, ExecuteRunWithoutPullAll) { WriteRunRequest(session, "RETURN 2"); session.Execute(); - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); } TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) { @@ -332,7 +330,7 @@ TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) { uint8_t data[2] = {0x00, dataset[i][1]}; ExecuteCommand(session, data, sizeof(data)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -349,7 +347,7 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) { session.socket_.SetWriteSuccess(i == 0); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); if (i == 0) { CheckFailureMessage(output); @@ -380,12 +378,12 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) { ExecuteCommand(session, dataset[i], 2); if (j == 0) { - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.socket_.IsOpen()); ASSERT_FALSE(session.encoder_buffer_.HasData()); PrintOutput(output); } else { - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -400,7 +398,7 @@ TEST(BoltSession, ExecuteInvalidMessage) { ExecuteInit(session, output); ExecuteCommand(session, init_req, sizeof(init_req)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -425,11 +423,11 @@ TEST(BoltSession, ErrorIgnoreMessage) { ASSERT_EQ(session.decoder_buffer_.Size(), 0); if (i == 0) { - ASSERT_EQ(session.state_, StateT::ErrorIdle); + ASSERT_EQ(session.state_, State::ErrorIdle); ASSERT_TRUE(session.socket_.IsOpen()); CheckOutput(output, ignored_resp, sizeof(ignored_resp)); } else { - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -451,13 +449,13 @@ TEST(BoltSession, ErrorRunAfterRun) { session.socket_.SetWriteSuccess(true); // Session holds results of last run. - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); // New run request. WriteRunRequest(session, "MATCH (n) RETURN n"); session.Execute(); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); } @@ -475,7 +473,7 @@ TEST(BoltSession, ErrorCantCleanup) { // there is data missing in the request, cleanup should fail ExecuteCommand(session, init_req, sizeof(init_req) - 10); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -495,7 +493,7 @@ TEST(BoltSession, ErrorWrongMarker) { uint8_t data[2] = {0x00, init_req[1]}; ExecuteCommand(session, data, sizeof(data)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -524,11 +522,11 @@ TEST(BoltSession, ErrorOK) { ASSERT_EQ(session.decoder_buffer_.Size(), 0); if (j == 0) { - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.socket_.IsOpen()); CheckOutput(output, success_resp, sizeof(success_resp)); } else { - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } @@ -551,7 +549,7 @@ TEST(BoltSession, ErrorMissingData) { uint8_t data[1] = {0x00}; ExecuteCommand(session, data, sizeof(data)); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); } @@ -565,7 +563,7 @@ TEST(BoltSession, MultipleChunksInOneExecute) { WriteRunRequest(session, "CREATE (n) RETURN n"); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.socket_.IsOpen()); PrintOutput(output); @@ -597,7 +595,7 @@ TEST(BoltSession, PartialChunk) { // missing chunk tail session.Execute(); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); @@ -605,7 +603,7 @@ TEST(BoltSession, PartialChunk) { session.Execute(); - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_GT(output.size(), 0); PrintOutput(output); @@ -624,25 +622,25 @@ TEST(BoltSession, ExplicitTransactionValidQueries) { WriteRunRequest(session, "BEGIN"); session.Execute(); - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); WriteRunRequest(session, "MATCH (n) RETURN n"); session.Execute(); - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); @@ -650,11 +648,11 @@ TEST(BoltSession, ExplicitTransactionValidQueries) { session.Execute(); ASSERT_FALSE(session.db_accessor_); CheckSuccessMessage(output); - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_FALSE(session.db_accessor_); CheckSuccessMessage(output); @@ -673,31 +671,31 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) { WriteRunRequest(session, "BEGIN"); session.Execute(); - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); WriteRunRequest(session, "MATCH ("); session.Execute(); - ASSERT_EQ(session.state_, StateT::ErrorWaitForRollback); + ASSERT_EQ(session.state_, State::ErrorWaitForRollback); ASSERT_TRUE(session.db_accessor_); CheckFailureMessage(output); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::ErrorWaitForRollback); + ASSERT_EQ(session.state_, State::ErrorWaitForRollback); ASSERT_TRUE(session.db_accessor_); CheckIgnoreMessage(output); ExecuteCommand(session, ackfailure_req, sizeof(ackfailure_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::WaitForRollback); + ASSERT_EQ(session.state_, State::WaitForRollback); ASSERT_TRUE(session.db_accessor_); CheckSuccessMessage(output); @@ -705,20 +703,20 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) { session.Execute(); if (transaction_end == "ROLLBACK") { - ASSERT_EQ(session.state_, StateT::Result); + ASSERT_EQ(session.state_, State::Result); ASSERT_FALSE(session.db_accessor_); ASSERT_TRUE(session.socket_.IsOpen()); CheckSuccessMessage(output); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); session.Execute(); - ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_EQ(session.state_, State::Idle); ASSERT_FALSE(session.db_accessor_); ASSERT_TRUE(session.socket_.IsOpen()); CheckSuccessMessage(output); } else { - ASSERT_EQ(session.state_, StateT::Close); + ASSERT_EQ(session.state_, State::Close); ASSERT_FALSE(session.db_accessor_); ASSERT_FALSE(session.socket_.IsOpen()); CheckFailureMessage(output); diff --git a/tests/unit/network_timeouts.cpp b/tests/unit/network_timeouts.cpp new file mode 100644 index 000000000..ceb87464d --- /dev/null +++ b/tests/unit/network_timeouts.cpp @@ -0,0 +1,104 @@ +#include <chrono> +#include <experimental/filesystem> +#include <iostream> + +#include <gflags/gflags.h> +#include <glog/logging.h> +#include "gtest/gtest.h" + +#include "communication/bolt/client.hpp" +#include "communication/bolt/v1/session.hpp" +#include "communication/server.hpp" +#include "io/network/network_endpoint.hpp" +#include "io/network/socket.hpp" + +DECLARE_int32(query_execution_time_sec); +DECLARE_int32(session_inactivity_timeout); + +using namespace std::chrono_literals; +class TestClientSocket; +using io::network::NetworkEndpoint; +using io::network::Socket; +using communication::bolt::SessionData; +using communication::bolt::ClientException; +using SessionT = communication::bolt::Session<Socket>; +using ResultStreamT = SessionT::ResultStreamT; +using ServerT = communication::Server<SessionT, SessionData>; +using ClientT = communication::bolt::Client<Socket>; + +class RunningServer { + public: + ~RunningServer() { + server_.Shutdown(); + server_thread_.join(); + } + + SessionData session_data_; + NetworkEndpoint endpoint_{"127.0.0.1", "0"}; + ServerT server_{endpoint_, session_data_}; + std::thread server_thread_{[&] { server_.Start(1); }}; +}; + +class TestClient : public ClientT { + public: + TestClient(NetworkEndpoint endpoint) + : ClientT( + [&] { + Socket socket; + socket.Connect(endpoint); + return socket; + }(), + "", "") {} +}; + +TEST(NetworkTimeouts, InactiveSession) { + FLAGS_query_execution_time_sec = 60; + FLAGS_session_inactivity_timeout = 1; + RunningServer rs; + + TestClient client(rs.server_.endpoint()); + // Check that we can execute first query. + client.Execute("RETURN 1", {}); + + // After sleep, session should still be alive. + std::this_thread::sleep_for(500ms); + client.Execute("RETURN 1", {}); + + // After sleep, session should still be alive. + std::this_thread::sleep_for(500ms); + client.Execute("RETURN 1", {}); + + // After sleep, session should still be alive. + std::this_thread::sleep_for(500ms); + client.Execute("RETURN 1", {}); + + // After sleep, session should have timed out. + std::this_thread::sleep_for(1500ms); + EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException); +} + +TEST(NetworkTimeouts, TimeoutInMultiCommandTransaction) { + FLAGS_query_execution_time_sec = 1; + FLAGS_session_inactivity_timeout = 60; + RunningServer rs; + + TestClient client(rs.server_.endpoint()); + + // Start explicit multicommand transaction. + client.Execute("BEGIN", {}); + client.Execute("RETURN 1", {}); + + // Session should still be alive. + std::this_thread::sleep_for(500ms); + client.Execute("RETURN 1", {}); + + // Session shouldn't be alive anymore. + std::this_thread::sleep_for(2s); + EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + google::InitGoogleLogging(argv[0]); + return RUN_ALL_TESTS(); +}