Refactor network stack

Summary:
Previously, the network stack's `communication::Server` accepted connections and
assigned them statically, in round-robin fashion, to `communication::Worker`
instances. That meant that if two compute-intensive connections were assigned
to the same worker, they would block each other while the other workers sat
idle.

This change replaces `communication::Worker` with `communication::Listener`,
which holds all accepted connections in a single pool so that any worker
thread can process any ready connection.

Reviewers: buda, florijan, teon.banek

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1220
Matej Ferencevic 2018-02-22 16:17:45 +01:00
parent fc75fadee3
commit 017e8004e8
20 changed files with 462 additions and 563 deletions
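
To make the change concrete, here is a short digest of the two dispatch models, assembled (and slightly simplified) from the diffs below:

// Old model: each accepted connection was pinned to one worker forever.
server_.workers_[server_.idx_]->AddConnection(std::move(*connection));
server_.idx_ = (server_.idx_ + 1) % server_.workers_.size();

// New model: all connections go into one shared Listener pool, and every
// worker thread pulls whatever event is ready next.
listener_.AddConnection(std::move(*s));
// in each of the `workers_count` worker threads:
while (alive_) listener_.WaitAndProcessEvents();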

View File

@ -1,5 +1,7 @@
#pragma once
#include <thread>
#include "glog/logging.h"
#include "communication/bolt/v1/constants.hpp"
@ -17,7 +19,9 @@
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
#include "query/interpreter.hpp"
#include "threading/sync/spinlock.hpp"
#include "transactions/transaction.hpp"
#include "utils/exceptions.hpp"
DECLARE_int32(session_inactivity_timeout);
@ -30,6 +34,16 @@ struct SessionData {
query::Interpreter interpreter;
};
/**
* Bolt Session Exception
*
* Used to indicate that something went wrong during session execution.
*/
class SessionException : public utils::BasicException {
public:
using utils::BasicException::BasicException;
};
/**
* Bolt Session
*
@ -40,25 +54,7 @@ struct SessionData {
template <typename TSocket>
class Session {
public:
// Wrapper around socket that checks if session has timed out on write
// failures, used in encoder buffer.
class TimeoutSocket {
public:
explicit TimeoutSocket(Session &session) : session_(session) {}
bool Write(const uint8_t *data, size_t len) {
// The have_more flag is hardcoded to false here because the bolt data
// is internally buffered and doesn't need to be buffered by the kernel.
return session_.socket_.Write(data, len, false,
[this] { return !session_.TimedOut(); });
}
private:
Session &session_;
};
using ResultStreamT =
ResultStream<Encoder<ChunkedEncoderBuffer<TimeoutSocket>>>;
using ResultStreamT = ResultStream<Encoder<ChunkedEncoderBuffer<TSocket>>>;
using StreamBuffer = io::network::StreamBuffer;
Session(TSocket &&socket, SessionData &data)
@ -67,8 +63,9 @@ class Session {
interpreter_(data.interpreter) {}
~Session() {
DCHECK(!db_accessor_)
<< "Transaction should have already be closed in Close";
if (db_accessor_) {
Abort();
}
}
/**
@ -158,8 +155,12 @@ class Session {
* Returns true if the session has timed out. A session times out if there was
* no activity in FLAGS_session_inactivity_timeout seconds, or if there is an
* active transaction with the should_abort flag set to true.
* This function must be thread-safe because it and `RefreshLastEventTime` are
* called from different threads in the network stack.
*/
bool TimedOut() const {
bool TimedOut() {
std::unique_lock<SpinLock> guard(lock_);
return db_accessor_
? db_accessor_->should_abort()
: last_event_time_ + std::chrono::seconds(
@ -167,17 +168,6 @@ class Session {
std::chrono::steady_clock::now();
}
/**
* Closes the session (client socket).
*/
void Close() {
DLOG(INFO) << "Closing session";
if (db_accessor_) {
Abort();
}
this->socket_.Close();
}
/**
* Commits associated transaction.
*/
@ -198,9 +188,16 @@ class Session {
TSocket &socket() { return socket_; }
/**
* Function that is called by the network stack to set the last event time.
* It is used to determine whether the session has timed out.
* This function must be thread-safe because it and `TimedOut` are called from
* different threads in the network stack.
*/
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
std::unique_lock<SpinLock> guard(lock_);
last_event_time_ = last_event_time;
}
@ -210,9 +207,8 @@ class Session {
database::MasterBase &db_;
query::Interpreter &interpreter_;
TimeoutSocket timeout_socket_{*this};
ChunkedEncoderBuffer<TimeoutSocket> encoder_buffer_{timeout_socket_};
Encoder<ChunkedEncoderBuffer<TimeoutSocket>> encoder_{encoder_buffer_};
ChunkedEncoderBuffer<TSocket> encoder_buffer_{socket_};
Encoder<ChunkedEncoderBuffer<TSocket>> encoder_{encoder_buffer_};
ResultStreamT output_stream_{encoder_};
Buffer<> buffer_;
@ -224,9 +220,11 @@ class Session {
// GraphDbAccessor of active transaction in the session, can be null if
// there is no associated transaction.
std::unique_ptr<database::GraphDbAccessor> db_accessor_;
// Time of the last event.
// Time of the last event and associated lock.
std::chrono::time_point<std::chrono::steady_clock> last_event_time_ =
std::chrono::steady_clock::now();
SpinLock lock_;
private:
void ClientFailureInvalidData() {
@ -235,10 +233,17 @@ class Session {
// We don't care about the return status because this is called when we
// are about to close the connection to the client.
encoder_buffer_.Clear();
encoder_.MessageFailure({{"code", "Memgraph.InvalidData"},
{"message", "The client has sent invalid data!"}});
// Close the connection.
Close();
encoder_.MessageFailure({{"code", "Memgraph.ExecutionException"},
{"message",
"Something went wrong while executing the query! "
"Check the server logs for more details."}});
// Throw an exception to indicate that something went wrong during session
// execution; this triggers session cleanup and closes the socket.
if (TimedOut()) {
throw SessionException("The session has timed out!");
} else {
throw SessionException("The client has sent invalid data!");
}
}
};
} // namespace communication::bolt
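
For orientation, a two-line digest (taken from the `communication::Listener` code later in this diff) of the two threads that make the `SpinLock` around `last_event_time_` necessary:

// Timeout-checker thread (Listener's background thread):
if (session->TimedOut()) session->socket().Shutdown();
// Worker thread (Listener::ReadAndProcessSession), before and after Execute():
session.RefreshLastEventTime(std::chrono::steady_clock::now());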

View File

@ -25,7 +25,7 @@ State StateHandshakeRun(TSession &session) {
// make sense to check which version the client prefers; this will change
// in the future.
if (!session.timeout_socket_.Write(kProtocol, sizeof(kProtocol))) {
if (!session.socket_.Write(kProtocol, sizeof(kProtocol))) {
DLOG(WARNING) << "Couldn't write handshake response!";
return State::Close;
}

View File

@ -0,0 +1,246 @@
#pragma once
#include <algorithm>
#include <atomic>
#include <chrono>
#include <memory>
#include <mutex>
#include <thread>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "threading/sync/spinlock.hpp"
namespace communication {
/**
* This class listens to events on an epoll object and processes them.
* When a new connection is added, a `TSession` object is created to handle
* it. When the `TSession` handler raises an exception or an error occurs, the
* `TSession` object is deleted and the corresponding socket is closed. This
* class also has a background thread that checks all sessions for expiration
* once per second and shuts down any session that has expired.
*/
template <class TSession, class TSessionData>
class Listener {
private:
// The maximum number of events handled per execution thread is 1 because
// each event represents the start of a network request whose processing can
// take a long time, so it doesn't make sense for a thread to claim more than
// one event at a time.
static const int kMaxEvents = 1;
public:
Listener(TSessionData &data, bool check_for_timeouts)
: data_(data), alive_(true) {
if (check_for_timeouts) {
thread_ = std::thread([this]() {
while (alive_) {
{
std::unique_lock<SpinLock> guard(lock_);
for (auto &session : sessions_) {
if (session->TimedOut()) {
LOG(WARNING) << "Session associated with "
<< session->socket().endpoint() << " timed out.";
// Here we shut down the socket to terminate any leftover blocking
// `Write` calls and to raise an event signaling that the session is
// closed. Session cleanup is done in the event processing function.
session->socket().Shutdown();
}
}
}
// TODO (mferencevic): Should this be configurable?
std::this_thread::sleep_for(std::chrono::seconds(1));
}
});
}
}
~Listener() {
alive_.store(false);
if (thread_.joinable()) thread_.join();
}
Listener(const Listener &) = delete;
Listener(Listener &&) = delete;
Listener &operator=(const Listener &) = delete;
Listener &operator=(Listener &&) = delete;
/**
* This function adds a socket to the listening event pool.
*
* @param connection socket which should be added to the event pool
*/
void AddConnection(io::network::Socket &&connection) {
std::unique_lock<SpinLock> guard(lock_);
// Set connection options.
// The socket is left in blocking mode, but `Read` is called with a flag
// that makes each individual read non-blocking, which is used in
// conjunction with `EPOLLET`. That means the socket is used in a
// non-blocking fashion for reads and in a blocking fashion for writes.
connection.SetKeepAlive();
connection.SetNoDelay();
// Remember fd before moving connection into Session.
int fd = connection.fd();
// Create a new Session for the connection.
sessions_.push_back(
std::make_unique<TSession>(std::move(connection), data_));
// Register the connection in Epoll.
// We want to listen for incoming events in edge-triggered mode, and we
// also want to listen for the hangup event. Epoll is hard to use
// concurrently, which is why we use `EPOLLONESHOT`; for a detailed
// description of the problems and why this is correct, see:
epoll_.Add(fd, EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT,
sessions_.back().get());
}
/**
* This function polls the event queue and processes incoming data.
* It is thread safe and is intended to be called from multiple threads and
* doesn't block the calling threads.
*/
void WaitAndProcessEvents() {
// This array can't be global because this function can be called from
// multiple threads, so it must be on the stack.
io::network::Epoll::Event events[kMaxEvents];
// Waits for events, stores at most kMaxEvents (1) of them in the events
// array, and waits at most 200 milliseconds. If the timeout is reached, it
// returns 0.
int n = epoll_.Wait(events, kMaxEvents, 200);
if (n <= 0) return;
// Process the event.
auto &event = events[0];
// We get the currently associated Session pointer and immediately
// dereference it here. It is safe to dereference the pointer because
// this design guarantees that there will never be an event that has
// a stale Session pointer.
TSession &session = *reinterpret_cast<TSession *>(event.data.ptr);
// Process epoll events. We use epoll in edge-triggered mode, so we process
// all events here. At most one of the `if` branches may be executed because
// each of them can call `CloseSession`, which destroys the session; calling
// anything on the session after that would cause a segfault.
if (event.events & EPOLLIN) {
// Read and process all incoming data.
while (ReadAndProcessSession(session))
;
} else if (event.events & EPOLLRDHUP) {
// The client closed the connection.
LOG(INFO) << "Client " << session.socket().endpoint()
<< " closed the connection.";
CloseSession(session);
} else if (!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR)) {
// There was an error on the server side.
LOG(ERROR) << "Error occured in session associated with "
<< session.socket().endpoint();
CloseSession(session);
} else {
// Unhandled epoll event.
LOG(ERROR) << "Unhandled event occured in session associated with "
<< session.socket().endpoint() << " events: " << event.events;
CloseSession(session);
}
}
private:
bool ReadAndProcessSession(TSession &session) {
// Refresh the last event time in the session.
// That function must be implemented in a thread-safe manner.
session.RefreshLastEventTime(std::chrono::steady_clock::now());
// Allocate the buffer to fill the data.
auto buf = session.Allocate();
// Read at most buf.len bytes from the socket in a non-blocking fashion.
int len = session.socket().Read(buf.data, buf.len, true);
// Check for read errors.
if (len == -1) {
// This means the read would block or was interrupted by a signal; we
// return `false` to stop reading data.
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
// Rearm epoll to send events from this socket.
epoll_.Modify(session.socket().fd(),
EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session);
return false;
}
// Some other error occurred, close the session.
CloseSession(session);
return false;
}
// The client has closed the connection.
if (len == 0) {
LOG(INFO) << "Client " << session.socket().endpoint()
<< " closed the connection.";
CloseSession(session);
return false;
}
// Notify the session that it has new data.
session.Written(len);
// Execute the session.
try {
session.Execute();
session.RefreshLastEventTime(std::chrono::steady_clock::now());
} catch (const std::exception &e) {
// Catch all exceptions.
LOG(ERROR) << "Exception was thrown while processing event in session "
"associated with "
<< session.socket().endpoint()
<< " with message: " << e.what();
CloseSession(session);
return false;
}
return true;
}
void CloseSession(TSession &session) {
// Deregister the Session's socket from epoll to disable further events. For
// a detailed description of why this must be done before destroying
// (closing) the socket, see:
// https://idea.popcount.org/2017-03-20-epoll-is-fundamentally-broken-22/
epoll_.Delete(session.socket().fd());
std::unique_lock<SpinLock> guard(lock_);
auto it =
std::find_if(sessions_.begin(), sessions_.end(),
[&](const auto &l) { return l->Id() == session.Id(); });
CHECK(it != sessions_.end())
<< "Trying to remove session that is not found in sessions!";
int i = it - sessions_.begin();
swap(sessions_[i], sessions_.back());
// This will call all destructors on the Session. Consequently, it will call
// the destructor on the Socket and close the socket.
sessions_.pop_back();
}
io::network::Epoll epoll_;
TSessionData &data_;
SpinLock lock_;
std::vector<std::unique_ptr<TSession>> sessions_;
std::thread thread_;
std::atomic<bool> alive_;
};
} // namespace communication
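
For reference, a minimal sketch of the interface that `Listener` expects from its `TSession` template parameter, inferred from the calls above (the class name is illustrative; the concrete sessions are the Bolt and RPC sessions elsewhere in this diff):

// Illustrative sketch only -- inferred from Listener's usage of TSession.
// Assumes "io/network/socket.hpp", "io/network/stream_buffer.hpp", <chrono>.
template <typename TSessionData>
class ExampleSession {
 public:
  // Constructed by Listener::AddConnection.
  ExampleSession(io::network::Socket &&socket, TSessionData &data);
  io::network::StreamBuffer Allocate();  // buffer to read incoming bytes into
  void Written(size_t len);              // notification that `len` bytes arrived
  void Execute();                        // run the protocol; may throw
  bool TimedOut();                       // must be thread-safe
  void RefreshLastEventTime(
      const std::chrono::time_point<std::chrono::steady_clock> &t);
  int Id() const;                        // unique id, typically the socket fd
  io::network::Socket &socket();
};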

View File

@ -16,8 +16,6 @@ namespace communication::rpc {
Session::Session(Socket &&socket, System &system)
: socket_(std::make_shared<Socket>(std::move(socket))), system_(system) {}
bool Session::Alive() const { return alive_; }
void Session::Execute() {
if (!handshake_done_) {
if (buffer_.size() < sizeof(MessageSize)) return;
@ -59,14 +57,6 @@ StreamBuffer Session::Allocate() { return buffer_.Allocate(); }
void Session::Written(size_t len) { buffer_.Written(len); }
void Session::Close() {
DLOG(INFO) << "Closing session";
// We explicitly close the socket here to remove the socket from the epoll
// event loop. The response message send will fail but that is OK and
// intended because the remote side closed the connection.
socket_.get()->Close();
}
void SendMessage(Socket &socket, uint32_t message_id,
std::unique_ptr<Message> &message) {
CHECK(message) << "Trying to send nullptr instead of message";

View File

@ -41,11 +41,6 @@ class Session {
int Id() const { return socket_->fd(); }
/**
* Returns the protocol alive state
*/
bool Alive() const;
/**
* Executes the protocol after data has been read into the buffer.
* Goes through the protocol states in order to execute commands from the
@ -71,11 +66,6 @@ class Session {
bool TimedOut() { return false; }
/**
* Closes the session (client socket).
*/
void Close();
Socket &socket() { return *socket_; }
void RefreshLastEventTime(
@ -93,7 +83,6 @@ class Session {
std::string service_name_;
bool handshake_done_{false};
bool alive_{true};
Buffer buffer_;
};

View File

@ -13,7 +13,7 @@ namespace communication::rpc {
System::System(const io::network::Endpoint &endpoint,
const size_t workers_count)
: server_(endpoint, *this, workers_count) {}
: server_(endpoint, *this, false, workers_count) {}
System::~System() {}

View File

@ -10,24 +10,21 @@
#include <fmt/format.h>
#include <glog/logging.h>
#include "communication/worker.hpp"
#include "communication/listener.hpp"
#include "io/network/socket.hpp"
#include "io/network/socket_event_dispatcher.hpp"
namespace communication {
/**
* TODO (mferencevic): document methods
*/
/**
* Communication server.
* Listens for incoming connections on the server port and assigns them in a
* round-robin manner to its workers. Started automatically on constructor, and
* stopped at destructor.
*
* Listens for incoming connections on the server port and assigns them to the
* connection listener. The listener processes the events with a thread pool
* that has `num_workers` threads. It is started automatically on constructor,
* and stopped at destructor.
*
* Current Server architecture:
* incoming connection -> server -> worker -> session
* incoming connection -> server -> listener -> session
*
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
@ -38,7 +35,6 @@ namespace communication {
template <typename TSession, typename TSessionData>
class Server {
public:
using WorkerT = Worker<TSession, TSessionData>;
using Socket = io::network::Socket;
/**
@ -46,43 +42,35 @@ class Server {
* invokes workers_count workers
*/
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
bool check_for_timeouts,
size_t workers_count = std::thread::hardware_concurrency())
: session_data_(session_data) {
: listener_(session_data, check_for_timeouts) {
// Without the server we can't continue with the application, so we just
// terminate here.
if (!socket_.Bind(endpoint)) {
LOG(FATAL) << "Cannot bind to socket on " << endpoint.address() << " at "
<< endpoint.port();
LOG(FATAL) << "Cannot bind to socket on " << endpoint;
}
socket_.SetNonBlocking();
socket_.SetTimeout(1, 0);
if (!socket_.Listen(1024)) {
LOG(FATAL) << "Cannot listen on socket!";
}
working_thread_ = std::thread([this, workers_count]() {
thread_ = std::thread([this, workers_count]() {
std::cout << fmt::format("Starting {} workers", workers_count)
<< std::endl;
workers_.reserve(workers_count);
for (size_t i = 0; i < workers_count; ++i) {
workers_.push_back(std::make_unique<WorkerT>(session_data_));
worker_threads_.emplace_back(
[this](WorkerT &worker) -> void { worker.Start(alive_); },
std::ref(*workers_.back()));
worker_threads_.emplace_back([this]() {
while (alive_) {
listener_.WaitAndProcessEvents();
}
});
}
std::cout << "Server is fully armed and operational" << std::endl;
std::cout << fmt::format("Listening on {} at {}",
socket_.endpoint().address(),
socket_.endpoint().port())
<< std::endl;
std::vector<std::unique_ptr<ConnectionAcceptor>> acceptors;
acceptors.emplace_back(
std::make_unique<ConnectionAcceptor>(socket_, *this));
auto &acceptor = *acceptors.back().get();
io::network::SocketEventDispatcher<ConnectionAcceptor> dispatcher{
acceptors};
dispatcher.AddListener(socket_.fd(), acceptor, EPOLLIN);
std::cout << "Server is fully armed and operational" << std::endl;
std::cout << "Listening on " << socket_.endpoint() << std::endl;
while (alive_) {
dispatcher.WaitAndProcessEvents();
AcceptConnection();
}
std::cout << "Shutting down..." << std::endl;
@ -97,6 +85,11 @@ class Server {
AwaitShutdown();
}
Server(const Server &) = delete;
Server(Server &&) = delete;
Server &operator=(const Server &) = delete;
Server &operator=(Server &&) = delete;
const auto &endpoint() const { return socket_.endpoint(); }
/// Stops server manually
@ -104,73 +97,33 @@ class Server {
// This should be as simple as possible, so that it can be called inside a
// signal handler.
alive_.store(false);
// Shutdown the socket to return from any waiting `Accept` calls.
socket_.Shutdown();
}
/// Waits for the server to be signaled to shutdown
void AwaitShutdown() {
if (working_thread_.joinable()) working_thread_.join();
if (thread_.joinable()) thread_.join();
}
private:
class ConnectionAcceptor {
public:
ConnectionAcceptor(Socket &socket, Server<TSession, TSessionData> &server)
: socket_(socket), server_(server) {}
void OnData() {
DCHECK(server_.idx_ < server_.workers_.size()) << "Invalid worker id.";
DLOG(INFO) << "On connect";
auto connection = AcceptConnection();
if (!connection) {
// Connection is not available anymore or configuration failed.
return;
}
server_.workers_[server_.idx_]->AddConnection(std::move(*connection));
server_.idx_ = (server_.idx_ + 1) % server_.workers_.size();
void AcceptConnection() {
// Accept a connection from a socket.
auto s = socket_.Accept();
if (!s) {
// Connection is not available anymore or configuration failed.
return;
}
LOG(INFO) << "Accepted a connection from " << s->endpoint();
listener_.AddConnection(std::move(*s));
}
void OnClose() { socket_.Close(); }
void OnException(const std::exception &e) {
LOG(FATAL) << "Exception was thrown while processing event on socket "
<< socket_.fd() << " with message: " << e.what();
}
void OnError() { LOG(FATAL) << "Error on server side occurred in epoll"; }
private:
// Accepts a connection on socket_ and configures the new connection. If done
// successfully, the new socket (connection) is returned, nullopt otherwise.
std::experimental::optional<Socket> AcceptConnection() {
DLOG(INFO) << "Accept new connection on socket: " << socket_.fd();
// Accept a connection from a socket.
auto s = socket_.Accept();
if (!s) return std::experimental::nullopt;
DLOG(INFO) << fmt::format(
"Accepted a connection: socket {}, address '{}', family {}, port {}",
s->fd(), s->endpoint().address(), s->endpoint().family(),
s->endpoint().port());
s->SetTimeout(1, 0);
s->SetKeepAlive();
s->SetNoDelay();
return s;
}
Socket &socket_;
Server<TSession, TSessionData> &server_;
};
std::vector<std::unique_ptr<WorkerT>> workers_;
std::vector<std::thread> worker_threads_;
std::thread working_thread_;
std::atomic<bool> alive_{true};
int idx_{0};
std::thread thread_;
std::vector<std::thread> worker_threads_;
Socket socket_;
TSessionData &session_data_;
Listener<TSession, TSessionData> listener_;
};
} // namespace communication
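
A usage sketch of the reworked constructor signature; this mirrors how `memgraph_bolt.cpp` and the tests below call it, with `TestSession`/`TestData` standing in for a concrete protocol:

// Sketch: the new `check_for_timeouts` flag sits between the session data and
// the worker count. The Bolt server passes `true`; RPC and the tests pass
// `false`.
io::network::Endpoint endpoint("127.0.0.1", 0);
TestData data;
communication::Server<TestSession, TestData> server(
    endpoint, data, /* check_for_timeouts = */ true, /* workers_count = */ 4);
server.AwaitShutdown();  // joins the accept thread after Shutdown() is called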

View File

@ -1,194 +0,0 @@
#pragma once
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdio>
#include <iomanip>
#include <memory>
#include <mutex>
#include <sstream>
#include <thread>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "io/network/network_error.hpp"
#include "io/network/socket.hpp"
#include "io/network/socket_event_dispatcher.hpp"
#include "io/network/stream_buffer.hpp"
#include "threading/sync/spinlock.hpp"
namespace communication {
/**
* TODO (mferencevic): document methods
*/
/**
* Communication worker.
* Listens for incoming data on connections and accepts new connections.
* Also, executes sessions on incoming data.
*
* @tparam TSession the worker can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
* @tparam TSessionData the class with objects that will be forwarded to the
* session
*/
template <typename TSession, typename TSessionData>
class Worker {
using Socket = io::network::Socket;
public:
void AddConnection(Socket &&connection) {
std::unique_lock<SpinLock> guard(lock_);
// Remember fd before moving connection into SessionListener.
int fd = connection.fd();
session_listeners_.push_back(
std::make_unique<SessionSocketListener>(std::move(connection), *this));
// We want to listen to an incoming event which is edge triggered and
// we also want to listen on the hangup event.
dispatcher_.AddListener(fd, *session_listeners_.back(),
EPOLLIN | EPOLLRDHUP);
}
explicit Worker(TSessionData &session_data) : session_data_(session_data) {}
void Start(std::atomic<bool> &alive) {
while (alive) {
dispatcher_.WaitAndProcessEvents();
bool check_sessions_for_timeouts = true;
while (check_sessions_for_timeouts) {
check_sessions_for_timeouts = false;
std::unique_lock<SpinLock> guard(lock_);
for (auto &session_listener : session_listeners_) {
if (session_listener->session().TimedOut()) {
// We need to unlock here because OnSessionAndTxTimeout will need to
// acquire the same lock.
guard.unlock();
session_listener->OnSessionAndTxTimeout();
// Since we released the lock we can't continue the iteration, so we
// need to break. There could still be more sessions that timed out, so
// we set check_sessions_for_timeouts back to true.
check_sessions_for_timeouts = true;
break;
}
}
}
}
}
private:
// TODO: Think about ownership. Who should own socket session,
// SessionSocketListener or Worker?
class SessionSocketListener {
public:
SessionSocketListener(Socket &&socket,
Worker<TSession, TSessionData> &worker)
: session_(std::move(socket), worker.session_data_), worker_(worker) {}
auto &session() { return session_; }
const auto &session() const { return session_; }
const auto &TimedOut() const { return session_.TimedOut(); }
void OnData() {
session_.RefreshLastEventTime(std::chrono::steady_clock::now());
DLOG(INFO) << "On data";
// allocate the buffer to fill the data
auto buf = session_.Allocate();
// read from the buffer at most buf.len bytes
int len = session_.socket().Read(buf.data, buf.len);
// check for read errors
if (len == -1) {
// This means read would block or read was interrupted by signal.
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
return;
}
// Some other error occurred, check errno.
OnError();
return;
}
// The client has closed the connection.
if (len == 0) {
DLOG(WARNING) << "Calling OnClose because the socket is closed!";
OnClose();
return;
}
// Notify the stream that it has new data.
session_.Written(len);
DLOG(INFO) << "OnRead";
try {
session_.Execute();
} catch (const std::exception &e) {
LOG(ERROR) << "Error occured while executing statement with message: "
<< e.what();
OnError();
}
session_.RefreshLastEventTime(std::chrono::steady_clock::now());
}
// TODO: Remove duplication in next three functions.
void OnError() {
LOG(ERROR) << fmt::format(
"Error occured in session associated with {}:{}",
session_.socket().endpoint().address(),
session_.socket().endpoint().port());
CloseSession();
}
void OnException(const std::exception &e) {
LOG(ERROR) << fmt::format(
"Exception was thrown while processing event in session associated "
"with {}:{} with message: {}",
session_.socket().endpoint().address(),
session_.socket().endpoint().port(), e.what());
CloseSession();
}
void OnSessionAndTxTimeout() {
LOG(WARNING) << fmt::format(
"Session or transaction associated with {}:{} timed out.",
session_.socket().endpoint().address(),
session_.socket().endpoint().port());
// TODO: report to the client what happened.
CloseSession();
}
void OnClose() {
LOG(INFO) << fmt::format("Client {}:{} closed the connection.",
session_.socket().endpoint().address(),
session_.socket().endpoint().port());
CloseSession();
}
private:
void CloseSession() {
session_.Close();
std::unique_lock<SpinLock> guard(worker_.lock_);
auto it = std::find_if(
worker_.session_listeners_.begin(), worker_.session_listeners_.end(),
[&](const auto &l) { return l->session_.Id() == session_.Id(); });
CHECK(it != worker_.session_listeners_.end())
<< "Trying to remove session that is not found in worker's sessions";
int i = it - worker_.session_listeners_.begin();
swap(worker_.session_listeners_[i], worker_.session_listeners_.back());
worker_.session_listeners_.pop_back();
}
TSession session_;
Worker &worker_;
};
SpinLock lock_;
TSessionData &session_data_;
std::vector<std::unique_ptr<SessionSocketListener>> session_listeners_;
io::network::SocketEventDispatcher<SessionSocketListener> dispatcher_{session_listeners_};
};
} // namespace communication

View File

@ -21,31 +21,78 @@ class Epoll {
public:
using Event = struct epoll_event;
explicit Epoll(int flags) : epoll_fd_(epoll_create1(flags)) {
Epoll(bool set_cloexec = false)
: epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) {
// epoll_create1 returns an error if there is a logical error in our code
// (for example invalid flags) or if there is an irrecoverable error. In both
// cases it is best to terminate.
CHECK(epoll_fd_ != -1) << "Error on epoll_create1, errno: " << errno
<< ", message: " << strerror(errno);
CHECK(epoll_fd_ != -1) << "Error on epoll create: (" << errno << ") "
<< strerror(errno);
}
void Add(int fd, Event *event) {
auto status = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, event);
/**
* This function adds (or, with modify=true, modifies) a file descriptor on
* which events should be listened for.
*
* @param fd file descriptor to add to epoll
* @param events epoll events mask
* @param ptr pointer to the associated event handler
* @param modify modify an existing file descriptor
*/
void Add(int fd, uint32_t events, void *ptr, bool modify = false) {
Event event;
event.events = events;
event.data.ptr = ptr;
int status = epoll_ctl(epoll_fd_, (modify ? EPOLL_CTL_MOD : EPOLL_CTL_ADD),
fd, &event);
// epoll_ctl can return an error on a logical error in our code or on an
// irrecoverable error. There is a third possibility: some system limit was
// reached. In that case we could return an error and close the connection,
// but the chances of reaching a system limit in a normally working memgraph
// are extremely small, so it is acceptable to terminate even in that case.
CHECK(!status) << "Error on epoll_ctl, errno: " << errno
<< ", message: " << strerror(errno);
CHECK(!status) << "Error on epoll " << (modify ? "modify" : "add") << ": ("
<< errno << ") " << strerror(errno);
}
/**
* This function modifies the event mask of an already-registered file
* descriptor.
*
* @param fd file descriptor to modify in epoll
* @param events epoll events mask
* @param ptr pointer to the associated event handler
*/
void Modify(int fd, uint32_t events, void *ptr) {
Add(fd, events, ptr, true);
}
/**
* This function removes a file descriptor from the set of file descriptors
* being listened on.
*
* @param fd file descriptor to delete from epoll
*/
void Delete(int fd) {
int status = epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, NULL);
// epoll_ctl can return an error on a logical error in our code or on an
// irrecoverable error. There is a third possibility: some system limit was
// reached. In that case we could return an error and close the connection,
// but the chances of reaching a system limit in a normally working memgraph
// are extremely small, so it is acceptable to terminate even in that case.
CHECK(!status) << "Error on epoll delete: (" << errno << ") "
<< strerror(errno);
}
/**
* This function waits for events from epoll.
* It can be called from multiple threads, but should be used with care in
* that case, see:
* https://stackoverflow.com/questions/7058737/is-epoll-thread-safe
*
* @param events array into which the ready events are stored
* @param max_events maximum number of events to return
* @param timeout wait timeout in milliseconds
*/
int Wait(Event *events, int max_events, int timeout) {
auto num_events = epoll_wait(epoll_fd_, events, max_events, timeout);
// If this check fails there was a logical error in our code.
CHECK(num_events != -1 || errno == EINTR)
<< "Error on epoll_wait, errno: " << errno
<< ", message: " << strerror(errno);
<< "Error on epoll wait: (" << errno << ") " << strerror(errno);
// num_events can be -1 if errno was EINTR (epoll_wait interrupted by signal
// handler). We treat that as no events, so we return 0.
return num_events == -1 ? 0 : num_events;
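
A minimal sketch of driving this wrapper with the `EPOLLONESHOT` disarm/re-arm cycle that `Listener` relies on (a hypothetical standalone loop, not code from this commit):

// Sketch only: one socket, single-threaded event loop.
io::network::Epoll epoll;
io::network::Socket socket;
io::network::Endpoint endpoint("127.0.0.1", 0);
CHECK(socket.Bind(endpoint) && socket.Listen(1024));
epoll.Add(socket.fd(), EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &socket);
io::network::Epoll::Event events[1];
while (true) {
  int n = epoll.Wait(events, 1, 200);  // waits at most 200 ms, 0 on timeout
  if (n <= 0) continue;
  auto *ready = reinterpret_cast<io::network::Socket *>(events[0].data.ptr);
  // ... handle the event on `ready` ...
  // EPOLLONESHOT disarms the fd after one delivery, so re-arm it explicitly:
  epoll.Modify(ready->fd(), EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT,
               ready);
}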

View File

@ -193,8 +193,7 @@ std::experimental::optional<Socket> Socket::Accept() {
return Socket(sfd, endpoint);
}
bool Socket::Write(const uint8_t *data, size_t len, bool have_more,
const std::function<bool()> &keep_retrying) {
bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
// MSG_NOSIGNAL is here to disable raising a SIGPIPE signal when a
// connection dies mid-write; the socket will only return an EPIPE error.
int flags = MSG_NOSIGNAL | (have_more ? MSG_MORE : 0);
@ -205,12 +204,8 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more,
// Terminal error, return failure.
return false;
}
// TODO: This can still cause a timed-out session to continue for a very
// long time. For example, if the timeout on send is 1 second and after
// every second we succeed in writing only one byte, then this function can
// block for len seconds. Change the semantics of the keep_retrying function
// so that this check can be done in the while loop even if send succeeds.
if (!keep_retrying()) return false;
// Non-fatal error, retry.
continue;
} else {
len -= written;
data += written;
@ -219,13 +214,12 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more,
return true;
}
bool Socket::Write(const std::string &s, bool have_more,
const std::function<bool()> &keep_retrying) {
return Write(reinterpret_cast<const uint8_t *>(s.data()), s.size(), have_more,
keep_retrying);
bool Socket::Write(const std::string &s, bool have_more) {
return Write(reinterpret_cast<const uint8_t *>(s.data()), s.size(),
have_more);
}
int Socket::Read(void *buffer, size_t len) {
return read(socket_, buffer, len);
int Socket::Read(void *buffer, size_t len, bool nonblock) {
return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0);
}
} // namespace io::network
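
A fragment sketching the new calling convention (the `socket` variable is illustrative): reads can be made non-blocking via `MSG_DONTWAIT`, while `Write` now retries non-fatal errors internally and blocks until everything is written or a terminal error occurs:

// Sketch: how Listener-style code is expected to use the new signatures.
uint8_t buf[4096];
int len = socket.Read(buf, sizeof(buf), /* nonblock = */ true);
if (len == -1 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)) {
  // Nothing to read right now -- re-arm epoll and return to the event loop.
} else if (len == 0) {
  // The client closed the connection.
} else if (len > 0) {
  // Blocking write; a timed-out session is unblocked by Socket::Shutdown()
  // called from the Listener's timeout thread, which makes this return false.
  bool ok = socket.Write(buf, static_cast<size_t>(len));
}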

View File

@ -132,23 +132,13 @@ class Socket {
* @param len length of char* or uint8_t* data
* @param have_more set to true if you plan to send more data to allow the
* kernel to buffer the data instead of immediately sending it out
* @param keep_retrying while this function evaluates to true, the socket will
* retry the write if a non-terminal error occurred (EAGAIN, EWOULDBLOCK,
* EINTR)... useful if the socket is in non-blocking mode or a timeout is set
* on the socket. By default Write doesn't retry if any error occurs.
*
* TODO: The retry logic could be in a derived class or in a wrapper of this
* class; unfortunately, from the current return value we don't know what
* error occurred nor how much data was written.
*
* @return write success status:
* true if write succeeded
* false if write failed
*/
bool Write(const uint8_t *data, size_t len, bool have_more = false,
const std::function<bool()> &keep_retrying = [] { return false; });
bool Write(const std::string &s, bool have_more = false,
const std::function<bool()> &keep_retrying = [] { return false; });
bool Write(const uint8_t *data, size_t len, bool have_more = false);
bool Write(const std::string &s, bool have_more = false);
/**
* Read data from the socket.
@ -156,17 +146,14 @@ class Socket {
*
* @param buffer pointer to the read buffer
* @param len length of the read buffer
* @param nonblock set to true if you want a non-blocking read
*
* @return read success status:
* > 0 if data was read, means number of read bytes
* == 0 if the client closed the connection
* < 0 if an error has occurred
*/
// TODO: Return type should be something like StatusOr<int> which would return
// number of read bytes if read succeeded and error code and error message
// otherwise (deduced from errno). We can implement that type easily on top of
// std::variant once c++17 becomes available in memgraph.
int Read(void *buffer, size_t len);
int Read(void *buffer, size_t len, bool nonblock = false);
private:
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}

View File

@ -1,95 +0,0 @@
#pragma once
#include <glog/logging.h>
#include "io/network/epoll.hpp"
#include "utils/crtp.hpp"
namespace io::network {
/**
* This class listens to events on an epoll object and calls
* callback functions to process them.
*/
template <class Listener>
class SocketEventDispatcher {
public:
explicit SocketEventDispatcher(
std::vector<std::unique_ptr<Listener>> &listeners, uint32_t flags = 0)
: epoll_(flags), listeners_(listeners) {}
void AddListener(int fd, Listener &listener, uint32_t events) {
// Add the listener associated to fd file descriptor to epoll.
epoll_event event;
event.events = events;
event.data.ptr = &listener;
epoll_.Add(fd, &event);
}
// Returns true if there was an event before the timeout.
bool WaitAndProcessEvents() {
// Waits for an event/multiple events and returns a maximum of max_events
// and stores them in the events array. It waits for wait_timeout
// milliseconds. If wait_timeout is achieved, returns 0.
const auto n = epoll_.Wait(events_, kMaxEvents, 200);
DLOG_IF(INFO, n > 0) << "number of events: " << n;
// Go through all events and process them in order.
for (int i = 0; i < n; ++i) {
auto &event = events_[i];
Listener *listener = reinterpret_cast<Listener *>(event.data.ptr);
// Check if the listener that we got in the epoll event still exists
// because it might have been deleted in a previous call.
auto it =
std::find_if(listeners_.begin(), listeners_.end(),
[&](const auto &l) { return l.get() == listener; });
// If the listener doesn't exist anymore just ignore the event.
if (it == listeners_.end()) continue;
// Even though it is possible for multiple events to be reported, we handle
// only one of them. Since we use epoll in level-triggered mode, unprocessed
// events will be reported the next time we call epoll_wait. This way of
// processing events is safer since callbacks can destroy the listener, and
// calling the next callback on that listener object would segfault. More
// subtle bugs are also possible: one callback can handle multiple events
// (maybe even in a subtly implicit way), and then we don't want to call
// multiple callbacks since we are not sure they are still valid.
try {
if (event.events & EPOLLIN) {
// We have some data waiting to be read.
listener->OnData();
continue;
}
if (event.events & EPOLLRDHUP) {
listener->OnClose();
continue;
}
// There was an error on the server side.
if (!(event.events & EPOLLIN) || event.events & (EPOLLHUP | EPOLLERR)) {
listener->OnError();
continue;
}
} catch (const std::exception &e) {
listener->OnException(e);
}
}
return n > 0;
}
private:
static const int kMaxEvents = 64;
// TODO: epoll is really ridiculous here. We don't plan to handle thousands
// of connections, so ppoll (or even a plain non-blocking socket) would
// actually be better.
Epoll epoll_;
Epoll::Event events_[kMaxEvents];
std::vector<std::unique_ptr<Listener>> &listeners_;
};
} // namespace io::network

View File

@ -107,7 +107,7 @@ void MasterMain() {
database::Master db;
SessionData session_data{db};
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, FLAGS_num_workers);
session_data, true, FLAGS_num_workers);
// Handler for regular termination signals
auto shutdown = [&server] {
@ -135,7 +135,7 @@ void SingleNodeMain() {
database::SingleNode db;
SessionData session_data{db};
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, FLAGS_num_workers);
session_data, true, FLAGS_num_workers);
// Handler for regular termination signals
auto shutdown = [&server] {

View File

@ -30,7 +30,6 @@ class TestSession {
event_.data.ptr = this;
}
bool Alive() const { return socket_.IsOpen(); }
bool TimedOut() const { return false; }
int Id() const { return socket_.fd(); }
@ -53,11 +52,6 @@ class TestSession {
void Written(size_t len) { buffer_.Written(len); }
void Close() {
DLOG(INFO) << "Close session!";
this->socket_.Close();
}
Socket &socket() { return socket_; }
void RefreshLastEventTime(

View File

@ -31,8 +31,6 @@ class TestSession {
event_.data.ptr = this;
}
bool Alive() const { return socket_.IsOpen(); }
bool TimedOut() const { return false; }
int Id() const { return socket_.fd(); }
@ -43,8 +41,6 @@ class TestSession {
void Written(size_t len) { buffer_.Written(len); }
void Close() { this->socket_.Close(); }
Socket &socket() { return socket_; }
void RefreshLastEventTime(
@ -86,7 +82,7 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
TestData data;
int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3;
communication::Server<TestSession, TestData> server(endpoint, data, N);
communication::Server<TestSession, TestData> server(endpoint, data, false, N);
const auto &ep = server.endpoint();
// start clients

View File

@ -21,7 +21,7 @@ TEST(Network, Server) {
// initialize server
TestData session_data;
int N = (std::thread::hardware_concurrency() + 1) / 2;
ServerT server(endpoint, session_data, N);
ServerT server(endpoint, session_data, false, N);
const auto &ep = server.endpoint();
// start clients

View File

@ -19,11 +19,10 @@ TEST(Network, SessionLeak) {
// initialize listen socket
Endpoint endpoint(interface, 0);
std::cout << endpoint << std::endl;
// initialize server
TestData session_data;
ServerT server(endpoint, session_data, 2);
ServerT server(endpoint, session_data, false, 2);
// start clients
int N = 50;

View File

@ -20,9 +20,6 @@ class TestSocket {
TestSocket(TestSocket &&) = default;
TestSocket &operator=(TestSocket &&) = default;
void Close() { socket_ = -1; }
bool IsOpen() { return socket_ != -1; }
int id() const { return socket_; }
bool Write(const uint8_t *data, size_t len, bool have_more = false,

View File

@ -16,6 +16,7 @@
std::vector<uint8_t> &output = session.socket().output;
using communication::bolt::SessionData;
using communication::bolt::SessionException;
using communication::bolt::State;
using SessionT = communication::bolt::Session<TestSocket>;
using ResultStreamT = SessionT::ResultStreamT;
@ -89,7 +90,6 @@ void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) {
session.Execute();
ASSERT_EQ(session.state_, State::Init);
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
}
@ -109,7 +109,6 @@ void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len,
void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) {
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
CheckOutput(output, init_resp, 7);
}
@ -149,10 +148,9 @@ TEST(BoltSession, HandshakeWrongPreamble) {
// copy 0x00000001 four times
for (int i = 0; i < 4; ++i) memcpy(buff.data + i * 4, handshake_req + 4, 4);
session.Written(20);
session.Execute();
ASSERT_THROW(session.Execute(), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
PrintOutput(output);
CheckFailureMessage(output);
}
@ -166,14 +164,12 @@ TEST(BoltSession, HandshakeInTwoPackets) {
session.Execute();
ASSERT_EQ(session.state_, State::Handshake);
ASSERT_TRUE(session.socket().IsOpen());
memcpy(buff.data + 10, handshake_req + 10, 10);
session.Written(10);
session.Execute();
ASSERT_EQ(session.state_, State::Init);
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
}
@ -181,10 +177,11 @@ TEST(BoltSession, HandshakeInTwoPackets) {
TEST(BoltSession, HandshakeWriteFail) {
INIT_VARS;
session.socket().SetWriteSuccess(false);
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false);
ASSERT_THROW(
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -196,10 +193,10 @@ TEST(BoltSession, HandshakeOK) {
TEST(BoltSession, InitWrongSignature) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteCommand(session, run_req_header, sizeof(run_req_header));
ASSERT_THROW(ExecuteCommand(session, run_req_header, sizeof(run_req_header)),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -209,10 +206,9 @@ TEST(BoltSession, InitWrongMarker) {
// wrong marker, good signature
uint8_t data[2] = {0x00, init_req[1]};
ExecuteCommand(session, data, 2);
ASSERT_THROW(ExecuteCommand(session, data, 2), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -224,10 +220,9 @@ TEST(BoltSession, InitMissingData) {
for (int i = 0; i < 3; ++i) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteCommand(session, init_req, len[i]);
ASSERT_THROW(ExecuteCommand(session, init_req, len[i]), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}
@ -236,10 +231,10 @@ TEST(BoltSession, InitWriteFail) {
INIT_VARS;
ExecuteHandshake(session, output);
session.socket().SetWriteSuccess(false);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -257,10 +252,9 @@ TEST(BoltSession, ExecuteRunWrongMarker) {
// wrong marker, good signature
uint8_t data[2] = {0x00, run_req_header[1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -273,10 +267,10 @@ TEST(BoltSession, ExecuteRunMissingData) {
INIT_VARS;
ExecuteHandshake(session, output);
ExecuteInit(session, output);
ExecuteCommand(session, run_req_header, len[i]);
ASSERT_THROW(ExecuteCommand(session, run_req_header, len[i]),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}
@ -291,15 +285,17 @@ TEST(BoltSession, ExecuteRunBasicException) {
session.socket().SetWriteSuccess(i == 0);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
if (i == 0) {
session.Execute();
} else {
ASSERT_THROW(session.Execute(), SessionException);
}
if (i == 0) {
ASSERT_EQ(session.state_, State::ErrorIdle);
ASSERT_TRUE(session.socket().IsOpen());
CheckFailureMessage(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -330,10 +326,9 @@ TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
// wrong marker, good signature
uint8_t data[2] = {0x00, dataset[i][1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}
@ -347,10 +342,10 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) {
ExecuteInit(session, output);
session.socket().SetWriteSuccess(i == 0);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_THROW(ExecuteCommand(session, pullall_req, sizeof(pullall_req)),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
if (i == 0) {
CheckFailureMessage(output);
} else {
@ -377,16 +372,18 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) {
if (j == 1) output.clear();
session.socket().SetWriteSuccess(j == 0);
ExecuteCommand(session, dataset[i], 2);
if (j == 0) {
ExecuteCommand(session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(session, dataset[i], 2), SessionException);
}
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket().IsOpen());
ASSERT_FALSE(session.encoder_buffer_.HasData());
PrintOutput(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -398,10 +395,10 @@ TEST(BoltSession, ExecuteInvalidMessage) {
ExecuteHandshake(session, output);
ExecuteInit(session, output);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -419,18 +416,21 @@ TEST(BoltSession, ErrorIgnoreMessage) {
output.clear();
session.socket().SetWriteSuccess(i == 0);
ExecuteCommand(session, init_req, sizeof(init_req));
if (i == 0) {
ExecuteCommand(session, init_req, sizeof(init_req));
} else {
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req)),
SessionException);
}
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (i == 0) {
ASSERT_EQ(session.state_, State::ErrorIdle);
ASSERT_TRUE(session.socket().IsOpen());
CheckOutput(output, ignored_resp, sizeof(ignored_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -455,10 +455,9 @@ TEST(BoltSession, ErrorRunAfterRun) {
// New run request.
WriteRunRequest(session, "MATCH (n) RETURN n");
session.Execute();
ASSERT_THROW(session.Execute(), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
}
TEST(BoltSession, ErrorCantCleanup) {
@ -473,10 +472,10 @@ TEST(BoltSession, ErrorCantCleanup) {
output.clear();
// there is data missing in the request, cleanup should fail
ExecuteCommand(session, init_req, sizeof(init_req) - 10);
ASSERT_THROW(ExecuteCommand(session, init_req, sizeof(init_req) - 10),
SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -493,10 +492,9 @@ TEST(BoltSession, ErrorWrongMarker) {
// wrong marker, good signature
uint8_t data[2] = {0x00, init_req[1]};
ExecuteCommand(session, data, sizeof(data));
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -518,18 +516,20 @@ TEST(BoltSession, ErrorOK) {
output.clear();
session.socket().SetWriteSuccess(j == 0);
ExecuteCommand(session, dataset[i], 2);
if (j == 0) {
ExecuteCommand(session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(session, dataset[i], 2), SessionException);
}
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket().IsOpen());
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -549,10 +549,9 @@ TEST(BoltSession, ErrorMissingData) {
// some marker, missing signature
uint8_t data[1] = {0x00};
ExecuteCommand(session, data, sizeof(data));
ASSERT_THROW(ExecuteCommand(session, data, sizeof(data)), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -566,7 +565,6 @@ TEST(BoltSession, MultipleChunksInOneExecute) {
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
// Count chunks in output
@ -598,15 +596,13 @@ TEST(BoltSession, PartialChunk) {
session.Execute();
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
WriteChunkTail(session);
session.Execute();
ASSERT_THROW(session.Execute(), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_GT(output.size(), 0);
PrintOutput(output);
}
@ -657,8 +653,6 @@ TEST(BoltSession, ExplicitTransactionValidQueries) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_);
CheckSuccessMessage(output);
ASSERT_TRUE(session.socket().IsOpen());
}
}
@ -702,25 +696,22 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) {
CheckSuccessMessage(output);
WriteRunRequest(session, transaction_end.c_str());
session.Execute();
if (transaction_end == "ROLLBACK") {
session.Execute();
ASSERT_EQ(session.state_, State::Result);
ASSERT_FALSE(session.db_accessor_);
ASSERT_TRUE(session.socket().IsOpen());
CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_);
ASSERT_TRUE(session.socket().IsOpen());
CheckSuccessMessage(output);
} else {
ASSERT_THROW(session.Execute(), SessionException);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.db_accessor_);
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}

View File

@ -31,7 +31,7 @@ class RunningServer {
database::SingleNode db_;
SessionData session_data_{db_};
Endpoint endpoint_{"127.0.0.1", 0};
ServerT server_{endpoint_, session_data_, 1};
ServerT server_{endpoint_, session_data_, true, 1};
};
class TestClient : public ClientT {
@ -48,7 +48,7 @@ class TestClient : public ClientT {
TEST(NetworkTimeouts, InactiveSession) {
FLAGS_query_execution_time_sec = 60;
FLAGS_session_inactivity_timeout = 1;
FLAGS_session_inactivity_timeout = 2;
RunningServer rs;
TestClient client(rs.server_.endpoint());
@ -68,12 +68,12 @@ TEST(NetworkTimeouts, InactiveSession) {
client.Execute("RETURN 1", {});
// After sleep, session should have timed out.
std::this_thread::sleep_for(1500ms);
std::this_thread::sleep_for(3500ms);
EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException);
}
TEST(NetworkTimeouts, TimeoutInMultiCommandTransaction) {
FLAGS_query_execution_time_sec = 1;
FLAGS_query_execution_time_sec = 2;
FLAGS_session_inactivity_timeout = 60;
RunningServer rs;
@ -88,7 +88,7 @@ TEST(NetworkTimeouts, TimeoutInMultiCommandTransaction) {
client.Execute("RETURN 1", {});
// Session shouldn't be alive anymore.
std::this_thread::sleep_for(2s);
std::this_thread::sleep_for(4s);
EXPECT_THROW(client.Execute("RETURN 1", {}), ClientException);
}