Close session on timeouts

Reviewers: mferencevic

Reviewed By: mferencevic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D897
This commit is contained in:
Mislav Bradac 2017-10-17 14:05:08 +02:00
parent 5525af558b
commit 19a44a7d94
24 changed files with 468 additions and 324 deletions

View File

@ -31,7 +31,8 @@ BASE_FLAGS = [
'-I./libs/gflags/include',
'-I./experimental/distributed/src',
'-I./experimental/distributed/libs/cereal/include',
'-I./libs/postgresql/include'
'-I./libs/postgresql/include',
'-I./build/include'
]
SOURCE_EXTENSIONS = [

View File

@ -180,6 +180,7 @@ target_link_libraries(antlr_opencypher_parser_lib antlr4)
# all memgraph src files
set(memgraph_src_files
${src_dir}/communication/bolt/v1/decoder/decoded_value.cpp
${src_dir}/communication/bolt/v1/session.cpp
${src_dir}/data_structures/concurrent/skiplist_gc.cpp
${src_dir}/database/graph_db.cpp
${src_dir}/database/graph_db_accessor.cpp

View File

@ -116,7 +116,6 @@ class ChunkedDecoderBuffer {
return ChunkState::Partial;
}
data_.reserve(data_.size() + chunk_size);
std::copy(data + 2, data + chunk_size + 2, std::back_inserter(data_));
buffer_.Shift(chunk_size + 2);

View File

@ -92,7 +92,6 @@ class ChunkedEncoderBuffer {
// 3. Copy whole chunk into the buffer.
size_ += pos_;
buffer_.reserve(size_);
std::copy(chunk_.begin(), chunk_.begin() + pos_,
std::back_inserter(buffer_));

View File

@ -0,0 +1,7 @@
#include "communication/bolt/v1/session.hpp"
// Danger: If multiple sessions are associated with one worker nonactive
// sessions could become blocked for FLAGS_session_inactivity_timeout time.
// TODO: We should never associate more sessions with one worker.
DEFINE_int32(session_inactivity_timeout, 1800,
"Time in seconds after which inactive sessions will be closed");

View File

@ -4,24 +4,24 @@
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
#include "database/dbms.hpp"
#include "query/interpreter.hpp"
#include "transactions/transaction.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp"
#include "communication/bolt/v1/decoder/decoder.hpp"
#include "communication/bolt/v1/encoder/encoder.hpp"
#include "communication/bolt/v1/encoder/result_stream.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/states/error.hpp"
#include "communication/bolt/v1/states/executing.hpp"
#include "communication/bolt/v1/states/handshake.hpp"
#include "communication/bolt/v1/states/init.hpp"
#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp"
#include "communication/bolt/v1/decoder/decoder.hpp"
#include "communication/bolt/v1/encoder/encoder.hpp"
#include "communication/bolt/v1/encoder/result_stream.hpp"
#include "io/network/stream_buffer.hpp"
DECLARE_int32(session_inactivity_timeout);
namespace communication::bolt {
@ -30,11 +30,7 @@ namespace communication::bolt {
*
* This class is responsible for holding references to Dbms and Interpreter
* that are passed through the network server and worker to the session.
*
* @tparam OutputStream type of output stream (could be a bolt output stream or
* a test output stream)
*/
template <typename OutputStream>
struct SessionData {
Dbms dbms;
query::Interpreter interpreter;
@ -45,32 +41,40 @@ struct SessionData {
*
* This class is responsible for handling a single client connection.
*
* @tparam Socket type of socket (could be a network socket or test socket)
* @tparam TSocket type of socket (could be a network socket or test socket)
*/
template <typename Socket>
template <typename TSocket>
class Session {
private:
using OutputStream = ResultStream<Encoder<ChunkedEncoderBuffer<Socket>>>;
public:
// Wrapper around socket that checks if session has timed out on write
// failures, used in encoder buffer.
class TimeoutSocket {
public:
explicit TimeoutSocket(Session &session) : session_(session) {}
bool Write(const uint8_t *data, size_t len) {
return session_.socket_.Write(data, len,
[this] { return !session_.TimedOut(); });
}
private:
Session &session_;
};
using ResultStreamT =
ResultStream<Encoder<ChunkedEncoderBuffer<TimeoutSocket>>>;
using StreamBuffer = io::network::StreamBuffer;
public:
Session(Socket &&socket, SessionData<OutputStream> &data)
Session(TSocket &&socket, SessionData &data)
: socket_(std::move(socket)),
dbms_(data.dbms),
interpreter_(data.interpreter) {
event_.data.ptr = this;
}
interpreter_(data.interpreter) {}
~Session() {
DCHECK(!db_accessor_)
<< "Transaction should have already be closed in Close";
}
/**
* @return is the session in a valid state
*/
bool Alive() const { return state_ != State::Close; }
/**
* @return the socket id
*/
@ -154,6 +158,19 @@ class Session {
*/
void Written(size_t len) { buffer_.Written(len); }
/**
* Returns true if session has timed out. Session times out if there was no
* activity in FLAGS_sessions_inactivity_timeout seconds or if there is a
* active transaction with shoul_abort flag set to true.
*/
bool TimedOut() const {
return db_accessor_
? db_accessor_->should_abort()
: last_event_time_ + std::chrono::seconds(
FLAGS_session_inactivity_timeout) <
std::chrono::steady_clock::now();
}
/**
* Closes the session (client socket).
*/
@ -183,37 +200,40 @@ class Session {
db_accessor_ = nullptr;
}
// TODO: Rethink if there is a way to hide some members. At the momemnt all of
// them are public.
Socket socket_;
// TODO: Rethink if there is a way to hide some members. At the momement all
// of them are public.
TSocket socket_;
Dbms &dbms_;
query::Interpreter &interpreter_;
ChunkedEncoderBuffer<Socket> encoder_buffer_{socket_};
Encoder<ChunkedEncoderBuffer<Socket>> encoder_{encoder_buffer_};
OutputStream output_stream_{encoder_};
TimeoutSocket timeout_socket_{*this};
ChunkedEncoderBuffer<TimeoutSocket> encoder_buffer_{timeout_socket_};
Encoder<ChunkedEncoderBuffer<TimeoutSocket>> encoder_{encoder_buffer_};
ResultStreamT output_stream_{encoder_};
Buffer<> buffer_;
ChunkedDecoderBuffer decoder_buffer_{buffer_};
Decoder<ChunkedDecoderBuffer> decoder_{decoder_buffer_};
io::network::Epoll::Event event_;
bool handshake_done_{false};
State state_{State::Handshake};
// GraphDbAccessor of active transaction in the session, can be null if there
// is no associated transaction.
// GraphDbAccessor of active transaction in the session, can be null if
// there is no associated transaction.
std::unique_ptr<GraphDbAccessor> db_accessor_;
// Time of the last event.
std::chrono::time_point<std::chrono::steady_clock> last_event_time_ =
std::chrono::steady_clock::now();
private:
void ClientFailureInvalidData() {
// set the state to Close
// Set the state to Close.
state_ = State::Close;
// don't care about the return status because this is always
// called when we are about to close the connection to the client
// We don't care about the return status because this is called when we
// are about to close the connection to the client.
encoder_buffer_.Clear();
encoder_.MessageFailure({{"code", "Memgraph.InvalidData"},
{"message", "The client has sent invalid data!"}});
// close the connection
// Close the connection.
Close();
}
};

View File

@ -16,8 +16,8 @@ namespace communication::bolt {
* The error state is exited upon receiving an ACK_FAILURE or RESET message.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateErrorRun(Session &session, State state) {
template <typename TSession>
State StateErrorRun(TSession &session, State state) {
Marker marker;
Signature signature;
if (!session.decoder_.ReadMessageHeader(&signature, &marker)) {
@ -60,16 +60,16 @@ State StateErrorRun(Session &session, State state) {
} else {
uint8_t value = underlying_cast(marker);
// all bolt client messages have less than 15 parameters
// so if we receive anything than a TinyStruct it's an error
// All bolt client messages have less than 15 parameters so if we receive
// anything than a TinyStruct it's an error.
if ((value & 0xF0) != underlying_cast(Marker::TinyStruct)) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct marker, but received 0x{:02X}!", value);
return State::Close;
}
// we need to clean up all parameters from this command
value &= 0x0F; // the length is stored in the lower nibble
// We need to clean up all parameters from this command.
value &= 0x0F; // The length is stored in the lower nibble.
DecodedValue dv;
for (int i = 0; i < value; ++i) {
if (!session.decoder_.ReadValue(&dv)) {
@ -79,13 +79,13 @@ State StateErrorRun(Session &session, State state) {
}
}
// ignore this message
// Ignore this message.
if (!session.encoder_.MessageIgnored()) {
DLOG(WARNING) << "Couldn't send ignored message!";
return State::Close;
}
// cleanup done, command ignored, stay in error state
// Cleanup done, command ignored, stay in error state.
return state;
}
}

View File

@ -15,8 +15,8 @@
namespace communication::bolt {
template <typename Session>
State HandleRun(Session &session, State state, Marker marker) {
template <typename TSession>
State HandleRun(TSession &session, State state, Marker marker) {
const std::map<std::string, query::TypedValue> kEmptyFields = {
{"fields", std::vector<query::TypedValue>{}}};

View File

@ -2,6 +2,7 @@
#include <glog/logging.h>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/state.hpp"
@ -12,8 +13,8 @@ namespace communication::bolt {
* This function runs everything to make a Bolt handshake with the client.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateHandshakeRun(Session &session) {
template <typename TSession>
State StateHandshakeRun(TSession &session) {
auto precmp = memcmp(session.buffer_.data(), kPreamble, sizeof(kPreamble));
if (UNLIKELY(precmp != 0)) {
DLOG(WARNING) << "Received a wrong preamble!";
@ -21,10 +22,10 @@ State StateHandshakeRun(Session &session) {
}
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers
// this will change in the future
// make sense to check which version the client prefers this will change in
// the future.
if (!session.socket_.Write(kProtocol, sizeof(kProtocol))) {
if (!session.timeout_socket_.Write(kProtocol, sizeof(kProtocol))) {
DLOG(WARNING) << "Couldn't write handshake response!";
return State::Close;
}

View File

@ -12,9 +12,9 @@
namespace communication::bolt {
/**
* Init state run function
* Init state run function.
* This function runs everything to initialize a Bolt session with the client.
* @param session the session that should be used for the run
* @param session the session that should be used for the run.
*/
template <typename Session>
State StateInitRun(Session &session) {

View File

@ -11,6 +11,7 @@
#include <glog/logging.h>
#include "communication/worker.hpp"
#include "io/network/socket.hpp"
#include "io/network/socket_event_dispatcher.hpp"
namespace communication {
@ -27,30 +28,44 @@ namespace communication {
* Current Server achitecture:
* incomming connection -> server -> worker -> session
*
* @tparam Session the server can handle different Sessions, each session
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam Socket the input/output socket that should be used
* @tparam SessionData the class with objects that will be forwarded to the
* @tparam TSessionData the class with objects that will be forwarded to the
* session
*/
// TODO: Remove Socket templatisation. Socket requirements are very specific.
// It needs to be in non blocking mode, etc.
template <typename Session, typename Socket, typename SessionData>
template <typename TSession, typename TSessionData>
class Server {
public:
using worker_t = Worker<Session, Socket, SessionData>;
using WorkerT = Worker<TSession, TSessionData>;
using Socket = io::network::Socket;
Server(Socket &&socket, SessionData &session_data)
: socket_(std::move(socket)), session_data_(session_data) {}
Server(const io::network::NetworkEndpoint &endpoint,
TSessionData &session_data)
: session_data_(session_data) {
// Without server we can't continue with application so we can just
// terminate here.
if (!socket_.Bind(endpoint)) {
LOG(FATAL) << "Cannot bind to socket on " << endpoint.address() << " at "
<< endpoint.port();
}
if (!socket_.SetNonBlocking()) {
LOG(FATAL) << "Cannot set socket to non blocking!";
}
if (!socket_.Listen(1024)) {
LOG(FATAL) << "Cannot listen on socket!";
}
}
const auto &endpoint() const { return socket_.endpoint(); }
void Start(size_t n) {
std::cout << fmt::format("Starting {} workers", n) << std::endl;
workers_.reserve(n);
for (size_t i = 0; i < n; ++i) {
workers_.push_back(std::make_unique<worker_t>(session_data_));
workers_.push_back(std::make_unique<WorkerT>(session_data_));
worker_threads_.emplace_back(
[this](worker_t &worker) -> void { worker.Start(alive_); },
[this](WorkerT &worker) -> void { worker.Start(alive_); },
std::ref(*workers_.back()));
}
std::cout << "Server is fully armed and operational" << std::endl;
@ -81,15 +96,14 @@ class Server {
private:
class ConnectionAcceptor : public io::network::BaseListener {
public:
ConnectionAcceptor(Socket &socket,
Server<Session, Socket, SessionData> &server)
ConnectionAcceptor(Socket &socket, Server<TSession, TSessionData> &server)
: io::network::BaseListener(socket), server_(server) {}
void OnData() {
DCHECK(server_.idx_ < server_.workers_.size()) << "Invalid worker id.";
DLOG(INFO) << "On connect";
auto connection = AcceptConnection();
if (UNLIKELY(!connection)) {
if (!connection) {
// Connection is not available anymore or configuration failed.
return;
}
@ -112,21 +126,22 @@ class Server {
s->fd(), s->endpoint().address(), s->endpoint().family(),
s->endpoint().port());
if (!s->SetTimeout(1, 0)) return std::experimental::nullopt;
if (!s->SetKeepAlive()) return std::experimental::nullopt;
if (!s->SetNoDelay()) return std::experimental::nullopt;
return s;
}
Server<Session, Socket, SessionData> &server_;
Server<TSession, TSessionData> &server_;
};
std::vector<std::unique_ptr<worker_t>> workers_;
std::vector<std::unique_ptr<WorkerT>> workers_;
std::vector<std::thread> worker_threads_;
std::atomic<bool> alive_{true};
int idx_{0};
Socket socket_;
SessionData &session_data_;
TSessionData &session_data_;
};
} // namespace communication

View File

@ -2,6 +2,7 @@
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdio>
#include <iomanip>
#include <memory>
@ -9,9 +10,11 @@
#include <sstream>
#include <thread>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "io/network/network_error.hpp"
#include "io/network/socket.hpp"
#include "io/network/socket_event_dispatcher.hpp"
#include "io/network/stream_buffer.hpp"
#include "threading/sync/spinlock.hpp"
@ -27,20 +30,19 @@ namespace communication {
* Listens for incomming data on connections and accepts new connections.
* Also, executes sessions on incomming data.
*
* @tparam Session the worker can handle different Sessions, each session
* @tparam TSession the worker can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam Socket the input/output socket that should be used
* @tparam SessionData the class with objects that will be forwarded to the
* @tparam TSessionData the class with objects that will be forwarded to the
* session
*/
template <typename Session, typename Socket, typename SessionData>
template <typename TSession, typename TSessionData>
class Worker {
using StreamBuffer = io::network::StreamBuffer;
using Socket = io::network::Socket;
public:
void AddConnection(Socket &&connection) {
std::unique_lock<SpinLock> gurad(lock_);
std::unique_lock<SpinLock> guard(lock_);
// Remember fd before moving connection into SessionListener.
int fd = connection.fd();
session_listeners_.push_back(
@ -51,11 +53,31 @@ class Worker {
EPOLLIN | EPOLLRDHUP);
}
Worker(SessionData &session_data) : session_data_(session_data) {}
Worker(TSessionData &session_data) : session_data_(session_data) {}
void Start(std::atomic<bool> &alive) {
while (alive) {
dispatcher_.WaitAndProcessEvents();
bool check_sessions_for_timeouts = true;
while (check_sessions_for_timeouts) {
check_sessions_for_timeouts = false;
std::unique_lock<SpinLock> guard(lock_);
for (auto &session_listener : session_listeners_) {
if (session_listener->session().TimedOut()) {
// We need to unlock here, because OnSessionAndTxTimeout will need
// to acquire same lock.
guard.unlock();
session_listener->OnSessionAndTxTimeout();
// Since we released lock we can't continue iteration so we need to
// break. There could still be more sessions that timed out so we
// set check_sessions_for_timeout back to true.
check_sessions_for_timeouts = true;
break;
}
}
}
}
}
@ -65,76 +87,96 @@ class Worker {
class SessionSocketListener : public io::network::BaseListener {
public:
SessionSocketListener(Socket &&socket,
Worker<Session, Socket, SessionData> &worker)
Worker<TSession, TSessionData> &worker)
: BaseListener(session_.socket_),
session_(std::move(socket), worker.session_data_),
worker_(worker) {}
void OnError() {
LOG(ERROR) << "Error occured in this session";
OnClose();
}
auto &session() { return session_; }
const auto &session() const { return session_; }
const auto &TimedOut() const { return session_.TimedOut(); }
void OnData() {
session_.last_event_time_ = std::chrono::steady_clock::now();
DLOG(INFO) << "On data";
if (UNLIKELY(!session_.Alive())) {
DLOG(WARNING) << "Calling OnClose because the stream isn't alive!";
OnClose();
return;
}
// allocate the buffer to fill the data
auto buf = session_.Allocate();
// read from the buffer at most buf.len bytes
int len = session_.socket_.Read(buf.data, buf.len);
// check for read errors
if (len == -1) {
// this means we have read all available data
if (LIKELY(errno == EAGAIN || errno == EWOULDBLOCK)) {
// This means read would block or read was interrupted by signal.
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
return;
}
// some other error occurred, check errno
// Some other error occurred, check errno.
OnError();
return;
}
// end of file, the client has closed the connection
if (UNLIKELY(len == 0)) {
// The client has closed the connection.
if (len == 0) {
DLOG(WARNING) << "Calling OnClose because the socket is closed!";
OnClose();
return;
}
// notify the stream that it has new data
// Notify the stream that it has new data.
session_.Written(len);
DLOG(INFO) << "OnRead";
try {
session_.Execute();
} catch (const std::exception &e) {
LOG(ERROR) << "Error occured while executing statement. " << std::endl
LOG(ERROR) << "Error occured while executing statement with message: "
<< e.what();
// TODO: report to client
OnError();
}
// TODO: Should we even continue with this session if error occurs while
// reading.
session_.last_event_time_ = std::chrono::steady_clock::now();
}
// TODO: Remove duplication in next three functions.
void OnError() {
LOG(ERROR) << fmt::format(
"Error occured in session associated with {}:{}",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port());
CloseSession();
}
void OnException(const std::exception &e) {
LOG(ERROR) << fmt::format(
"Exception was thrown while processing event in session associated "
"with {}:{} with message: {}",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port(), e.what());
CloseSession();
}
void OnSessionAndTxTimeout() {
LOG(WARNING) << fmt::format(
"Session or transaction associated with {}:{} timed out.",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port());
// TODO: report to client what happend.
CloseSession();
}
void OnClose() {
LOG(INFO) << fmt::format("Client {}:{} closed the connection.",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port())
<< std::endl;
session_.socket_.endpoint().port());
CloseSession();
}
private:
void CloseSession() {
session_.Close();
std::unique_lock<SpinLock> gurad(worker_.lock_);
std::unique_lock<SpinLock> guard(worker_.lock_);
auto it = std::find_if(
worker_.session_listeners_.begin(), worker_.session_listeners_.end(),
[&](const auto &l) { return l->session_.Id() == session_.Id(); });
CHECK(it != worker_.session_listeners_.end())
<< "Trying to remove session that is not found in worker's sessions";
int i = it - worker_.session_listeners_.begin();
@ -142,13 +184,12 @@ class Worker {
worker_.session_listeners_.pop_back();
}
private:
Session session_;
TSession session_;
Worker &worker_;
};
SpinLock lock_;
SessionData &session_data_;
TSessionData &session_data_;
io::network::SocketEventDispatcher<SessionSocketListener> dispatcher_;
std::vector<std::unique_ptr<SessionSocketListener>> session_listeners_;
};

View File

@ -21,8 +21,7 @@ class Epoll {
public:
using Event = struct epoll_event;
Epoll(int flags) {
epoll_fd_ = epoll_create1(flags);
Epoll(int flags) : epoll_fd_(epoll_create1(flags)) {
// epoll_create1 returns an error if there is a logical error in our code
// (for example invalid flags) or if there is irrecoverable error. In both
// cases it is best to terminate.
@ -53,6 +52,6 @@ class Epoll {
}
private:
int epoll_fd_;
const int epoll_fd_;
};
}

View File

@ -17,6 +17,7 @@
#include <unistd.h>
#include "io/network/addrinfo.hpp"
#include "threading/sync/cpu_relax.hpp"
#include "utils/likely.hpp"
namespace io::network {
@ -47,7 +48,7 @@ void Socket::Close() {
socket_ = -1;
}
bool Socket::IsOpen() { return socket_ != -1; }
bool Socket::IsOpen() const { return socket_ != -1; }
bool Socket::Connect(const NetworkEndpoint &endpoint) {
if (socket_ != -1) return false;
@ -194,25 +195,27 @@ std::experimental::optional<Socket> Socket::Accept() {
return Socket(sfd, endpoint);
}
const NetworkEndpoint &Socket::endpoint() const { return endpoint_; }
bool Socket::Write(const std::string &str) {
return Write(str.c_str(), str.size());
}
bool Socket::Write(const char *data, size_t len) {
return Write(reinterpret_cast<const uint8_t *>(data), len);
}
bool Socket::Write(const uint8_t *data, size_t len) {
bool Socket::Write(const uint8_t *data, size_t len,
const std::function<bool()> &keep_retrying) {
while (len > 0) {
// MSG_NOSIGNAL is here to disable raising a SIGPIPE
// signal when a connection dies mid-write, the socket
// will only return an EPIPE error
// MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a
// connection dies mid-write, the socket will only return an EPIPE error.
auto written = send(socket_, data, len, MSG_NOSIGNAL);
if (UNLIKELY(written == -1)) return false;
len -= written;
data += written;
if (written == -1) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
// Terminal error, return failure.
return false;
}
// TODO: This can still cause timed out session to continue for a very
// long time. For example if timeout on send is 1 second and after every
// sencond we succeed in writing only one byte that this function can
// block for len seconds. Change semantics of keep_retrying function so
// that this check can be done in while loop even if send succeeds.
if (!keep_retrying()) return false;
} else {
len -= written;
data += written;
}
}
return true;
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <experimental/optional>
#include <functional>
#include <iostream>
#include "io/network/network_endpoint.hpp"
@ -34,7 +35,7 @@ class Socket {
* true if the socket is open
* false if the socket is closed
*/
bool IsOpen();
bool IsOpen() const;
/**
* Connects the socket to the specified endpoint.
@ -127,23 +128,29 @@ class Socket {
/**
* Returns the currently active endpoint of the socket.
*/
const NetworkEndpoint &endpoint() const;
const NetworkEndpoint &endpoint() const { return endpoint_; }
/**
* Write data to the socket.
* Theese functions guarantee that all data will be written.
*
* @param str std::string to write to the socket
* @param data char* or uint8_t* to data that should be written
* @param data uint8_t* to data that should be written
* @param len length of char* or uint8_t* data
* @param keep_retrying while function executes to true socket will retry to
* write data if nonterminal error occurred on socket (EAGAIN, EWOULDBLOCK,
* EINTR)... useful if socket is in nonblocking mode or timeout is set on a
* socket. By default Write doesn't retry if any error occurrs.
*
* TODO: Logic for retrying can be in derived class or in a wrapper of this
* class, unfortunately from current return value we don't know what error
* occured nor how much data was written.
*
* @return write success status:
* true if write succeeded
* false if write failed
*/
bool Write(const std::string &str);
bool Write(const char *data, size_t len);
bool Write(const uint8_t *data, size_t len);
bool Write(const uint8_t *data, size_t len,
const std::function<bool()> &keep_retrying = [] { return false; });
/**
* Read data from the socket.
@ -157,6 +164,10 @@ class Socket {
* == 0 if the client closed the connection
* < 0 if an error has occurred
*/
// TODO: Return type should be something like StatusOr<int> which would return
// number of read bytes if read succeeded and error code and error message
// otherwise (deduced from errno). We can implement that type easily on top of
// std::variant once c++17 becomes available in memgraph.
int Read(void *buffer, size_t len);
private:

View File

@ -43,14 +43,13 @@ class SocketEventDispatcher {
// probably what we want to do.
try {
// Hangup event.
if (UNLIKELY(event.events & EPOLLRDHUP)) {
if (event.events & EPOLLRDHUP) {
listener.OnClose();
continue;
}
// There was an error on the server side.
if (UNLIKELY(!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR))) {
if (!(event.events & EPOLLIN) || event.events & (EPOLLHUP | EPOLLERR)) {
listener.OnError();
continue;
}
@ -88,16 +87,11 @@ class BaseListener {
void OnData() {}
void OnException(const std::exception &e) {
// TODO: this actually sounds quite bad, maybe we should close socket here
// because we don'y know in which state Listener class is.
LOG(ERROR) << "Exception was thrown while processing event on socket "
LOG(FATAL) << "Exception was thrown while processing event on socket "
<< socket_.fd() << " with message: " << e.what();
}
void OnError() {
LOG(ERROR) << "Error on server side occured in epoll";
socket_.Close();
}
void OnError() { LOG(FATAL) << "Error on server side occured in epoll"; }
protected:
Socket &socket_;

View File

@ -21,15 +21,12 @@
#include "version.hpp"
namespace fs = std::experimental::filesystem;
using endpoint_t = io::network::NetworkEndpoint;
using socket_t = io::network::Socket;
using session_t = communication::bolt::Session<socket_t>;
using result_stream_t =
communication::bolt::ResultStream<communication::bolt::Encoder<
communication::bolt::ChunkedEncoderBuffer<socket_t>>>;
using session_data_t = communication::bolt::SessionData<result_stream_t>;
using bolt_server_t =
communication::Server<session_t, socket_t, session_data_t>;
using io::network::NetworkEndpoint;
using io::network::Socket;
using communication::bolt::SessionData;
using SessionT = communication::bolt::Session<Socket>;
using ResultStreamT = SessionT::ResultStreamT;
using ServerT = communication::Server<SessionT, SessionData>;
DEFINE_string(interface, "0.0.0.0",
"Communication interface on which to listen.");
@ -37,8 +34,7 @@ DEFINE_string(port, "7687", "Communication port on which to listen.");
DEFINE_VALIDATED_int32(num_workers,
std::max(std::thread::hardware_concurrency(), 1U),
"Number of workers", FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(log_file, "",
"Path to where the log should be stored.");
DEFINE_string(log_file, "", "Path to where the log should be stored.");
DEFINE_string(log_link_basename, "",
"Basename used for symlink creation to the last log file.");
DEFINE_uint64(memory_warning_threshold, 1024,
@ -52,7 +48,7 @@ DEFINE_uint64(memory_warning_threshold, 1024,
// 3) env - MEMGRAPH_CONFIG
// 4) command line flags
void load_config(int &argc, char **&argv) {
void LoadConfig(int &argc, char **&argv) {
std::vector<fs::path> configs = {fs::path("/etc/memgraph/memgraph.conf")};
if (getenv("HOME") != nullptr)
configs.emplace_back(fs::path(getenv("HOME")) /
@ -99,7 +95,7 @@ int main(int argc, char **argv) {
google::SetUsageMessage("Memgraph database server");
gflags::SetVersionString(version_string);
load_config(argc, argv);
LoadConfig(argc, argv);
google::InitGoogleLogging(argv[0]);
google::SetLogDestination(google::INFO, FLAGS_log_file.c_str());
@ -123,31 +119,19 @@ int main(int argc, char **argv) {
});
// Initialize bolt session data (Dbms and Interpreter).
session_data_t session_data;
SessionData session_data;
// Initialize endpoint.
endpoint_t endpoint;
try {
endpoint = endpoint_t(FLAGS_interface, FLAGS_port);
} catch (io::network::NetworkEndpointException &e) {
LOG(FATAL) << e.what();
}
// Initialize socket.
socket_t socket;
if (!socket.Bind(endpoint)) {
LOG(FATAL) << "Cannot bind to socket on " << FLAGS_interface << " at "
<< FLAGS_port;
}
if (!socket.SetNonBlocking()) {
LOG(FATAL) << "Cannot set socket to non blocking!";
}
if (!socket.Listen(1024)) {
LOG(FATAL) << "Cannot listen on socket!";
}
NetworkEndpoint endpoint = [&] {
try {
return NetworkEndpoint(FLAGS_interface, FLAGS_port);
} catch (io::network::NetworkEndpointException &e) {
LOG(FATAL) << e.what();
}
}();
// Initialize server.
bolt_server_t server(std::move(socket), session_data);
ServerT server(endpoint, session_data);
// register SIGTERM handler
SignalHandler::register_handler(Signal::Terminate,

View File

@ -1,6 +1,7 @@
#pragma once
#include <array>
#include <chrono>
#include <cstring>
#include <iostream>
#include <random>
@ -11,24 +12,26 @@
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/server.hpp"
#include "database/graph_db_accessor.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
static constexpr const int SIZE = 60000;
static constexpr const int REPLY = 10;
using endpoint_t = io::network::NetworkEndpoint;
using socket_t = io::network::Socket;
using io::network::NetworkEndpoint;
using io::network::Socket;
class TestData {};
class TestSession {
public:
TestSession(socket_t &&socket, TestData &) : socket_(std::move(socket)) {
TestSession(Socket &&socket, TestData &) : socket_(std::move(socket)) {
event_.data.ptr = this;
}
bool Alive() { return socket_.IsOpen(); }
bool Alive() const { return socket_.IsOpen(); }
bool TimedOut() const { return false; }
int Id() const { return socket_.fd(); }
@ -56,15 +59,12 @@ class TestSession {
}
communication::bolt::Buffer<SIZE * 2> buffer_;
socket_t socket_;
Socket socket_;
io::network::Epoll::Event event_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
};
using test_server_t = communication::Server<TestSession, socket_t, TestData>;
void server_start(void *serverptr, int num) {
((test_server_t *)serverptr)->Start(num);
}
using ServerT = communication::Server<TestSession, TestData>;
void client_run(int num, const char *interface, const char *port,
const unsigned char *data, int lo, int hi) {
@ -72,8 +72,8 @@ void client_run(int num, const char *interface, const char *port,
name << "Client " << num;
unsigned char buffer[SIZE * REPLY], head[2];
int have, read;
endpoint_t endpoint(interface, port);
socket_t socket;
NetworkEndpoint endpoint(interface, port);
Socket socket;
ASSERT_TRUE(socket.Connect(endpoint));
ASSERT_TRUE(socket.SetTimeout(2, 0));
DLOG(INFO) << "Socket create: " << socket.fd();

View File

@ -14,23 +14,26 @@
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/server.hpp"
#include "database/graph_db_accessor.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
static constexpr const char interface[] = "127.0.0.1";
using endpoint_t = io::network::NetworkEndpoint;
using socket_t = io::network::Socket;
using io::network::NetworkEndpoint;
using io::network::Socket;
class TestData {};
class TestSession {
public:
TestSession(socket_t &&socket, TestData &) : socket_(std::move(socket)) {
TestSession(Socket &&socket, TestData &) : socket_(std::move(socket)) {
event_.data.ptr = this;
}
bool Alive() { return socket_.IsOpen(); }
bool Alive() const { return socket_.IsOpen(); }
bool TimedOut() const { return false; }
int Id() const { return socket_.fd(); }
@ -42,19 +45,17 @@ class TestSession {
void Close() { this->socket_.Close(); }
socket_t socket_;
Socket socket_;
communication::bolt::Buffer<> buffer_;
io::network::Epoll::Event event_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
};
using test_server_t = communication::Server<TestSession, socket_t, TestData>;
test_server_t *serverptr;
std::atomic<bool> run{true};
void client_run(int num, const char *interface, const char *port) {
endpoint_t endpoint(interface, port);
socket_t socket;
NetworkEndpoint endpoint(interface, port);
Socket socket;
uint8_t data = 0x00;
ASSERT_TRUE(socket.Connect(endpoint));
ASSERT_TRUE(socket.SetTimeout(1, 0));
@ -67,32 +68,22 @@ void client_run(int num, const char *interface, const char *port) {
socket.Close();
}
void server_run(void *serverptr, int num) {
((test_server_t *)serverptr)->Start(num);
}
TEST(Network, SocketReadHangOnConcurrentConnections) {
// initialize listen socket
endpoint_t endpoint(interface, "0");
socket_t socket;
ASSERT_TRUE(socket.Bind(endpoint));
ASSERT_TRUE(socket.SetNonBlocking());
ASSERT_TRUE(socket.Listen(1024));
NetworkEndpoint endpoint(interface, "0");
// get bound address
auto ep = socket.endpoint();
printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port());
printf("ADDRESS: %s, PORT: %d\n", endpoint.address(), endpoint.port());
// initialize server
TestData data;
test_server_t server(std::move(socket), data);
serverptr = &server;
communication::Server<TestSession, TestData> server(endpoint, data);
// start server
int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3;
std::thread server_thread(server_run, serverptr, N);
std::thread server_thread([&] { server.Start(N); });
const auto &ep = server.endpoint();
// start clients
std::vector<std::thread> clients;
for (int i = 0; i < Nc; ++i)

View File

@ -8,32 +8,23 @@ static constexpr const char interface[] = "127.0.0.1";
unsigned char data[SIZE];
test_server_t *serverptr;
TEST(Network, Server) {
// initialize test data
initialize_data(data, SIZE);
// initialize listen socket
endpoint_t endpoint(interface, "0");
socket_t socket;
ASSERT_TRUE(socket.Bind(endpoint));
ASSERT_TRUE(socket.SetNonBlocking());
ASSERT_TRUE(socket.Listen(1024));
// get bound address
auto ep = socket.endpoint();
printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port());
NetworkEndpoint endpoint(interface, "0");
printf("ADDRESS: %s, PORT: %d\n", endpoint.address(), endpoint.port());
// initialize server
TestData session_data;
test_server_t server(std::move(socket), session_data);
serverptr = &server;
ServerT server(endpoint, session_data);
// start server
int N = (std::thread::hardware_concurrency() + 1) / 2;
std::thread server_thread(server_start, serverptr, N);
std::thread server_thread([&] { server.Start(N); });
const auto &ep = server.endpoint();
// start clients
std::vector<std::thread> clients;
for (int i = 0; i < N; ++i)

View File

@ -10,8 +10,6 @@ static constexpr const char interface[] = "127.0.0.1";
unsigned char data[SIZE];
test_server_t *serverptr;
using namespace std::chrono_literals;
TEST(Network, SessionLeak) {
@ -19,28 +17,21 @@ TEST(Network, SessionLeak) {
initialize_data(data, SIZE);
// initialize listen socket
endpoint_t endpoint(interface, "0");
socket_t socket;
ASSERT_TRUE(socket.Bind(endpoint));
ASSERT_TRUE(socket.SetNonBlocking());
ASSERT_TRUE(socket.Listen(1024));
// get bound address
auto ep = socket.endpoint();
printf("ADDRESS: %s, PORT: %d\n", ep.address(), ep.port());
NetworkEndpoint endpoint(interface, "0");
printf("ADDRESS: %s, PORT: %d\n", endpoint.address(), endpoint.port());
// initialize server
TestData session_data;
test_server_t server(std::move(socket), session_data);
serverptr = &server;
ServerT server(endpoint, session_data);
// start server
std::thread server_thread(server_start, serverptr, 2);
std::thread server_thread([&] { server.Start(2); });
// start clients
int N = 50;
std::vector<std::thread> clients;
const auto &ep = server.endpoint();
int testlen = 3000;
for (int i = 0; i < N; ++i) {
clients.push_back(std::thread(client_run, i, interface, ep.port_str(), data,

View File

@ -1,5 +1,6 @@
#include <array>
#include <cstring>
#include <functional>
#include <iostream>
#include <random>
#include <vector>
@ -13,26 +14,19 @@
*/
class TestSocket {
public:
TestSocket(int socket) : socket(socket) {}
TestSocket(const TestSocket &s) : socket(s.id()){};
TestSocket(TestSocket &&other) { *this = std::forward<TestSocket>(other); }
explicit TestSocket(int socket) : socket_(socket) {}
TestSocket(const TestSocket &) = default;
TestSocket &operator=(const TestSocket &) = default;
TestSocket(TestSocket &&) = default;
TestSocket &operator=(TestSocket &&) = default;
TestSocket &operator=(TestSocket &&other) {
this->socket = other.socket;
other.socket = -1;
return *this;
}
void Close() { socket_ = -1; }
bool IsOpen() { return socket_ != -1; }
void Close() { socket = -1; }
bool IsOpen() { return socket != -1; }
int id() const { return socket_; }
int id() const { return socket; }
bool Write(const std::string &str) { return Write(str.c_str(), str.size()); }
bool Write(const char *data, size_t len) {
return Write(reinterpret_cast<const uint8_t *>(data), len);
}
bool Write(const uint8_t *data, size_t len) {
bool Write(const uint8_t *data, size_t len,
const std::function<bool()> & = [] { return false; }) {
if (!write_success_) return false;
for (size_t i = 0; i < len; ++i) output.push_back(data[i]);
return true;
@ -43,7 +37,7 @@ class TestSocket {
std::vector<uint8_t> output;
protected:
int socket;
int socket_;
bool write_success_{true};
};

View File

@ -7,18 +7,16 @@
// TODO: This could be done in fixture.
// Shortcuts for writing variable initializations in tests
#define INIT_VARS \
TestSocket socket(10); \
SessionDataT session_data; \
SessionT session(std::move(socket), session_data); \
#define INIT_VARS \
TestSocket socket(10); \
SessionData session_data; \
SessionT session(std::move(socket), session_data); \
std::vector<uint8_t> &output = session.socket_.output;
using ResultStreamT =
communication::bolt::ResultStream<communication::bolt::Encoder<
communication::bolt::ChunkedEncoderBuffer<TestSocket>>>;
using SessionDataT = communication::bolt::SessionData<ResultStreamT>;
using communication::bolt::State;
using communication::bolt::SessionData;
using SessionT = communication::bolt::Session<TestSocket>;
using StateT = communication::bolt::State;
using ResultStreamT = SessionT::ResultStreamT;
// Sample testdata that has correct inputs and outputs.
const uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00,
@ -88,7 +86,7 @@ void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) {
session.Written(20);
session.Execute();
ASSERT_EQ(session.state_, StateT::Init);
ASSERT_EQ(session.state_, State::Init);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
@ -108,7 +106,7 @@ void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len,
// Execute and check a correct init
void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) {
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
CheckOutput(output, init_resp, 7);
@ -151,7 +149,7 @@ TEST(BoltSession, HandshakeWrongPreamble) {
session.Written(20);
session.Execute();
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
PrintOutput(output);
CheckFailureMessage(output);
@ -165,14 +163,14 @@ TEST(BoltSession, HandshakeInTwoPackets) {
session.Written(10);
session.Execute();
ASSERT_EQ(session.state_, StateT::Handshake);
ASSERT_EQ(session.state_, State::Handshake);
ASSERT_TRUE(session.socket_.IsOpen());
memcpy(buff.data + 10, handshake_req + 10, 10);
session.Written(10);
session.Execute();
ASSERT_EQ(session.state_, StateT::Init);
ASSERT_EQ(session.state_, State::Init);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
@ -183,7 +181,7 @@ TEST(BoltSession, HandshakeWriteFail) {
session.socket_.SetWriteSuccess(false);
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -198,7 +196,7 @@ TEST(BoltSession, InitWrongSignature) {
ExecuteHandshake(session, output);
ExecuteCommand(session, run_req_header, sizeof(run_req_header));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -211,7 +209,7 @@ TEST(BoltSession, InitWrongMarker) {
uint8_t data[2] = {0x00, init_req[1]};
ExecuteCommand(session, data, 2);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -226,7 +224,7 @@ TEST(BoltSession, InitMissingData) {
ExecuteHandshake(session, output);
ExecuteCommand(session, init_req, len[i]);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -238,7 +236,7 @@ TEST(BoltSession, InitWriteFail) {
session.socket_.SetWriteSuccess(false);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -259,7 +257,7 @@ TEST(BoltSession, ExecuteRunWrongMarker) {
uint8_t data[2] = {0x00, run_req_header[1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -275,7 +273,7 @@ TEST(BoltSession, ExecuteRunMissingData) {
ExecuteInit(session, output);
ExecuteCommand(session, run_req_header, len[i]);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -294,11 +292,11 @@ TEST(BoltSession, ExecuteRunBasicException) {
session.Execute();
if (i == 0) {
ASSERT_EQ(session.state_, StateT::ErrorIdle);
ASSERT_EQ(session.state_, State::ErrorIdle);
ASSERT_TRUE(session.socket_.IsOpen());
CheckFailureMessage(output);
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -314,7 +312,7 @@ TEST(BoltSession, ExecuteRunWithoutPullAll) {
WriteRunRequest(session, "RETURN 2");
session.Execute();
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
}
TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
@ -332,7 +330,7 @@ TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
uint8_t data[2] = {0x00, dataset[i][1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -349,7 +347,7 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) {
session.socket_.SetWriteSuccess(i == 0);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
if (i == 0) {
CheckFailureMessage(output);
@ -380,12 +378,12 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) {
ExecuteCommand(session, dataset[i], 2);
if (j == 0) {
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_FALSE(session.encoder_buffer_.HasData());
PrintOutput(output);
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -400,7 +398,7 @@ TEST(BoltSession, ExecuteInvalidMessage) {
ExecuteInit(session, output);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -425,11 +423,11 @@ TEST(BoltSession, ErrorIgnoreMessage) {
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (i == 0) {
ASSERT_EQ(session.state_, StateT::ErrorIdle);
ASSERT_EQ(session.state_, State::ErrorIdle);
ASSERT_TRUE(session.socket_.IsOpen());
CheckOutput(output, ignored_resp, sizeof(ignored_resp));
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -451,13 +449,13 @@ TEST(BoltSession, ErrorRunAfterRun) {
session.socket_.SetWriteSuccess(true);
// Session holds results of last run.
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
// New run request.
WriteRunRequest(session, "MATCH (n) RETURN n");
session.Execute();
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
}
@ -475,7 +473,7 @@ TEST(BoltSession, ErrorCantCleanup) {
// there is data missing in the request, cleanup should fail
ExecuteCommand(session, init_req, sizeof(init_req) - 10);
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -495,7 +493,7 @@ TEST(BoltSession, ErrorWrongMarker) {
uint8_t data[2] = {0x00, init_req[1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -524,11 +522,11 @@ TEST(BoltSession, ErrorOK) {
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -551,7 +549,7 @@ TEST(BoltSession, ErrorMissingData) {
uint8_t data[1] = {0x00};
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);
}
@ -565,7 +563,7 @@ TEST(BoltSession, MultipleChunksInOneExecute) {
WriteRunRequest(session, "CREATE (n) RETURN n");
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
PrintOutput(output);
@ -597,7 +595,7 @@ TEST(BoltSession, PartialChunk) {
// missing chunk tail
session.Execute();
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_EQ(output.size(), 0);
@ -605,7 +603,7 @@ TEST(BoltSession, PartialChunk) {
session.Execute();
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_GT(output.size(), 0);
PrintOutput(output);
@ -624,25 +622,25 @@ TEST(BoltSession, ExplicitTransactionValidQueries) {
WriteRunRequest(session, "BEGIN");
session.Execute();
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
WriteRunRequest(session, "MATCH (n) RETURN n");
session.Execute();
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
@ -650,11 +648,11 @@ TEST(BoltSession, ExplicitTransactionValidQueries) {
session.Execute();
ASSERT_FALSE(session.db_accessor_);
CheckSuccessMessage(output);
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_);
CheckSuccessMessage(output);
@ -673,31 +671,31 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) {
WriteRunRequest(session, "BEGIN");
session.Execute();
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
WriteRunRequest(session, "MATCH (");
session.Execute();
ASSERT_EQ(session.state_, StateT::ErrorWaitForRollback);
ASSERT_EQ(session.state_, State::ErrorWaitForRollback);
ASSERT_TRUE(session.db_accessor_);
CheckFailureMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::ErrorWaitForRollback);
ASSERT_EQ(session.state_, State::ErrorWaitForRollback);
ASSERT_TRUE(session.db_accessor_);
CheckIgnoreMessage(output);
ExecuteCommand(session, ackfailure_req, sizeof(ackfailure_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::WaitForRollback);
ASSERT_EQ(session.state_, State::WaitForRollback);
ASSERT_TRUE(session.db_accessor_);
CheckSuccessMessage(output);
@ -705,20 +703,20 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) {
session.Execute();
if (transaction_end == "ROLLBACK") {
ASSERT_EQ(session.state_, StateT::Result);
ASSERT_EQ(session.state_, State::Result);
ASSERT_FALSE(session.db_accessor_);
ASSERT_TRUE(session.socket_.IsOpen());
CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, StateT::Idle);
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_);
ASSERT_TRUE(session.socket_.IsOpen());
CheckSuccessMessage(output);
} else {
ASSERT_EQ(session.state_, StateT::Close);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.db_accessor_);
ASSERT_FALSE(session.socket_.IsOpen());
CheckFailureMessage(output);

View File

@ -0,0 +1,104 @@
#include <chrono>
#include <experimental/filesystem>
#include <iostream>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "gtest/gtest.h"
#include "communication/bolt/client.hpp"
#include "communication/bolt/v1/session.hpp"
#include "communication/server.hpp"
#include "io/network/network_endpoint.hpp"
#include "io/network/socket.hpp"
DECLARE_int32(query_execution_time_sec);
DECLARE_int32(session_inactivity_timeout);
using namespace std::chrono_literals;
class TestClientSocket;
using io::network::NetworkEndpoint;
using io::network::Socket;
using communication::bolt::SessionData;
using communication::bolt::ClientException;
using SessionT = communication::bolt::Session<Socket>;
using ResultStreamT = SessionT::ResultStreamT;
using ServerT = communication::Server<SessionT, SessionData>;
using ClientT = communication::bolt::Client<Socket>;
class RunningServer {
public:
~RunningServer() {
server_.Shutdown();
server_thread_.join();
}
SessionData session_data_;
NetworkEndpoint endpoint_{"127.0.0.1", "0"};
ServerT server_{endpoint_, session_data_};
std::thread server_thread_{[&] { server_.Start(1); }};
};
class TestClient : public ClientT {
public:
TestClient(NetworkEndpoint endpoint)
: ClientT(
[&] {
Socket socket;
socket.Connect(endpoint);
return socket;
}(),
"", "") {}
};
TEST(NetworkTimeouts, InactiveSession) {
FLAGS_query_execution_time_sec = 60;
FLAGS_session_inactivity_timeout = 1;
RunningServer rs;
TestClient client(rs.server_.endpoint());
// Check that we can execute first query.
client.Execute("RETURN 1", {});
// After sleep, session should still be alive.
std::this_thread::sleep_for(500ms);
client.Execute("RETURN 1", {});
// After sleep, session should still be alive.
std::this_thread::sleep_for(500ms);
client.Execute("RETURN 1", {});
// After sleep, session should still be alive.
std::this_thread::sleep_for(500ms);
client.Execute("RETURN 1", {});
// After sleep, session should have timed out.
std::this_thread::sleep_for(1500ms);
EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException);
}
TEST(NetworkTimeouts, TimeoutInMultiCommandTransaction) {
FLAGS_query_execution_time_sec = 1;
FLAGS_session_inactivity_timeout = 60;
RunningServer rs;
TestClient client(rs.server_.endpoint());
// Start explicit multicommand transaction.
client.Execute("BEGIN", {});
client.Execute("RETURN 1", {});
// Session should still be alive.
std::this_thread::sleep_for(500ms);
client.Execute("RETURN 1", {});
// Session shouldn't be alive anymore.
std::this_thread::sleep_for(2s);
EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
google::InitGoogleLogging(argv[0]);
return RUN_ALL_TESTS();
}